diff --git a/sqlspec/base.py b/sqlspec/base.py index 855432b17..b01efee26 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -1,4 +1,6 @@ # ruff: noqa: PLR6301 +import atexit +import contextlib import re from abc import ABC, abstractmethod from collections.abc import Awaitable @@ -19,6 +21,7 @@ from sqlspec.exceptions import NotFoundError from sqlspec.statement import SQLStatement from sqlspec.typing import ModelDTOT, StatementParameterType +from sqlspec.utils.sync_tools import maybe_async_ if TYPE_CHECKING: from contextlib import AbstractAsyncContextManager, AbstractContextManager @@ -202,6 +205,15 @@ class SQLSpec: def __init__(self) -> None: self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {} + # Register the cleanup handler to run at program exit + atexit.register(self._cleanup_pools) + + def _cleanup_pools(self) -> None: + """Clean up all open database pools at program exit.""" + for config in self._configs.values(): + if config.support_connection_pooling and config.pool_instance is not None: + with contextlib.suppress(Exception): + maybe_async_(config.close_pool)() @overload def add_config(self, config: "SyncConfigT") -> "type[SyncConfigT]": ...