22
33from __future__ import annotations
44
5+ import asyncio
56import logging
67
78import sqlalchemy
89from sqlalchemy import lambda_stmt
910from sqlalchemy .exc import SQLAlchemyError
11+ from sqlalchemy .ext .asyncio import (
12+ AsyncSession ,
13+ async_scoped_session ,
14+ async_sessionmaker ,
15+ create_async_engine ,
16+ )
1017from sqlalchemy .orm import Session , scoped_session , sessionmaker
1118from sqlalchemy .sql .lambdas import StatementLambdaElement
1219from sqlalchemy .util import LRUCache
@@ -55,7 +62,9 @@ def validate_sql_select(value: str) -> str:
5562
5663async def async_create_sessionmaker (
5764 hass : HomeAssistant , db_url : str
58- ) -> tuple [scoped_session | None , bool , bool ]:
65+ ) -> tuple [
66+ async_scoped_session [AsyncSession ] | scoped_session [Session ] | None , bool , bool
67+ ]:
5968 """Create a session maker for the given db_url.
6069
6170 This function gets or creates a SQLAlchemy `scoped_session` for the given
@@ -83,7 +92,7 @@ async def async_create_sessionmaker(
8392 uses_recorder_db = False
8493 else :
8594 uses_recorder_db = db_url == instance .db_url
86- sessmaker : scoped_session | None
95+ sessmaker : async_scoped_session [ AsyncSession ] | scoped_session [ Session ] | None
8796 sql_data = _async_get_or_init_domain_data (hass )
8897 use_database_executor = False
8998 if uses_recorder_db and instance .dialect_name == SupportedDialect .SQLITE :
@@ -98,6 +107,9 @@ async def async_create_sessionmaker(
98107 # for every sensor.
99108 elif db_url in sql_data .session_makers_by_db_url :
100109 sessmaker = sql_data .session_makers_by_db_url [db_url ]
110+ elif "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url :
111+ if sessmaker := await _async_validate_and_get_session_maker_for_db_url (db_url ):
112+ sql_data .session_makers_by_db_url [db_url ] = sessmaker
101113 elif sessmaker := await hass .async_add_executor_job (
102114 _validate_and_get_session_maker_for_db_url , db_url
103115 ):
@@ -169,7 +181,9 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData:
169181 sql_data : SQLData = hass .data [DOMAIN ]
170182 return sql_data
171183
172- session_makers_by_db_url : dict [str , scoped_session ] = {}
184+ session_makers_by_db_url : dict [
185+ str , async_scoped_session [AsyncSession ] | scoped_session [Session ]
186+ ] = {}
173187
174188 #
175189 # Ensure we dispose of all engines at shutdown
@@ -178,10 +192,13 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData:
178192 # Shutdown all sessions in the executor since they will
179193 # do blocking I/O
180194 #
181- def _shutdown_db_engines (event : Event ) -> None :
195+ async def _shutdown_db_engines (_ : Event ) -> None :
182196 """Shutdown all database engines."""
183197 for sessmaker in session_makers_by_db_url .values ():
184- sessmaker .connection ().engine .dispose ()
198+ if isinstance (sessmaker , async_scoped_session ):
199+ await (await sessmaker .connection ()).engine .dispose ()
200+ else :
201+ sessmaker .connection ().engine .dispose ()
185202
186203 cancel_shutdown = hass .bus .async_listen_once (
187204 EVENT_HOMEASSISTANT_STOP , _shutdown_db_engines
@@ -192,18 +209,46 @@ def _shutdown_db_engines(event: Event) -> None:
192209 return sql_data
193210
194211
195- def _validate_and_get_session_maker_for_db_url (db_url : str ) -> scoped_session | None :
212+ async def _async_validate_and_get_session_maker_for_db_url (
213+ db_url : str ,
214+ ) -> async_scoped_session [AsyncSession ] | None :
215+ """Validate the db_url and return a async session maker."""
216+ try :
217+ maker = async_scoped_session (
218+ async_sessionmaker (bind = create_async_engine (db_url )),
219+ scopefunc = asyncio .current_task ,
220+ )
221+ # Run a dummy query just to test the db_url
222+ async with maker () as session :
223+ await session .execute (sqlalchemy .text ("SELECT 1;" ))
224+
225+ except SQLAlchemyError as err :
226+ _LOGGER .error (
227+ "Couldn't connect using %s DB_URL: %s" ,
228+ redact_credentials (db_url ),
229+ redact_credentials (str (err )),
230+ )
231+ return None
232+ else :
233+ return maker
234+
235+
236+ def _validate_and_get_session_maker_for_db_url (
237+ db_url : str ,
238+ ) -> scoped_session [Session ] | None :
196239 """Validate the db_url and return a session maker.
197240
198241 This does I/O and should be run in the executor.
199242 """
200- sess : Session | None = None
201243 try :
202- engine = sqlalchemy .create_engine (db_url , future = True )
203- sessmaker = scoped_session (sessionmaker (bind = engine , future = True ))
244+ maker = scoped_session (
245+ sessionmaker (
246+ bind = sqlalchemy .create_engine (db_url , future = True ), future = True
247+ )
248+ )
204249 # Run a dummy query just to test the db_url
205- sess = sessmaker ()
206- sess .execute (sqlalchemy .text ("SELECT 1;" ))
250+ with maker () as session :
251+ session .execute (sqlalchemy .text ("SELECT 1;" ))
207252
208253 except SQLAlchemyError as err :
209254 _LOGGER .error (
@@ -213,10 +258,7 @@ def _validate_and_get_session_maker_for_db_url(db_url: str) -> scoped_session |
213258 )
214259 return None
215260 else :
216- return sessmaker
217- finally :
218- if sess :
219- sess .close ()
261+ return maker
220262
221263
222264def generate_lambda_stmt (query : str ) -> StatementLambdaElement :
0 commit comments