|
30 | 30 |
|
31 | 31 | _LOGGER = logging.getLogger(__name__) |
32 | 32 |
|
| 33 | +_SQL_SELECT = sqlalchemy.text("SELECT 1;") |
33 | 34 | _SQL_LAMBDA_CACHE: LRUCache = LRUCache(1000) |
34 | 35 |
|
35 | 36 |
|
@@ -107,11 +108,8 @@ async def async_create_sessionmaker( |
107 | 108 | # for every sensor. |
108 | 109 | elif db_url in sql_data.session_makers_by_db_url: |
109 | 110 | sessmaker = sql_data.session_makers_by_db_url[db_url] |
110 | | - elif "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url: |
111 | | - if sessmaker := await _async_validate_and_get_session_maker_for_db_url(db_url): |
112 | | - sql_data.session_makers_by_db_url[db_url] = sessmaker |
113 | | - elif sessmaker := await hass.async_add_executor_job( |
114 | | - _validate_and_get_session_maker_for_db_url, db_url |
| 111 | + elif sessmaker := await _async_validate_and_get_session_maker_for_db_url( |
| 112 | + hass, db_url |
115 | 113 | ): |
116 | 114 | sql_data.session_makers_by_db_url[db_url] = sessmaker |
117 | 115 | else: |
@@ -210,57 +208,46 @@ async def _shutdown_db_engines(_: Event) -> None: |
210 | 208 |
|
211 | 209 |
|
212 | 210 | async def _async_validate_and_get_session_maker_for_db_url( |
213 | | - db_url: str, |
214 | | -) -> async_scoped_session[AsyncSession] | None: |
| 211 | + hass: HomeAssistant, db_url: str |
| 212 | +) -> async_scoped_session[AsyncSession] | scoped_session[Session] | None: |
215 | 213 | """Validate the db_url and return a async session maker.""" |
216 | 214 | try: |
217 | | - maker = async_scoped_session( |
218 | | - async_sessionmaker( |
219 | | - bind=create_async_engine(db_url, future=True), future=True |
220 | | - ), |
221 | | - scopefunc=asyncio.current_task, |
222 | | - ) |
223 | | - # Run a dummy query just to test the db_url |
224 | | - async with maker() as session: |
225 | | - await session.execute(sqlalchemy.text("SELECT 1;")) |
226 | | - |
227 | | - except SQLAlchemyError as err: |
228 | | - _LOGGER.error( |
229 | | - "Couldn't connect using %s DB_URL: %s", |
230 | | - redact_credentials(db_url), |
231 | | - redact_credentials(str(err)), |
232 | | - ) |
233 | | - return None |
234 | | - else: |
235 | | - return maker |
236 | | - |
237 | | - |
238 | | -def _validate_and_get_session_maker_for_db_url( |
239 | | - db_url: str, |
240 | | -) -> scoped_session[Session] | None: |
241 | | - """Validate the db_url and return a session maker. |
242 | | -
|
243 | | - This does I/O and should be run in the executor. |
244 | | - """ |
245 | | - try: |
246 | | - maker = scoped_session( |
247 | | - sessionmaker( |
248 | | - bind=sqlalchemy.create_engine(db_url, future=True), future=True |
| 215 | + if "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url: |
| 216 | + maker = async_scoped_session( |
| 217 | + async_sessionmaker( |
| 218 | + bind=create_async_engine(db_url, future=True), future=True |
| 219 | + ), |
| 220 | + scopefunc=asyncio.current_task, |
249 | 221 | ) |
250 | | - ) |
251 | | - # Run a dummy query just to test the db_url |
252 | | - with maker() as session: |
253 | | - session.execute(sqlalchemy.text("SELECT 1;")) |
| 222 | + # Run a dummy query just to test the db_url |
| 223 | + async with maker() as session: |
| 224 | + await session.execute(_SQL_SELECT) |
| 225 | + return maker |
| 226 | + |
| 227 | + def _get_session_maker_for_db_url() -> scoped_session[Session] | None: |
| 228 | + """Validate the db_url and return a session maker. |
| 229 | +
|
| 230 | + This does I/O and should be run in the executor. |
| 231 | + """ |
| 232 | + maker = scoped_session( |
| 233 | + sessionmaker( |
| 234 | + bind=sqlalchemy.create_engine(db_url, future=True), future=True |
| 235 | + ) |
| 236 | + ) |
| 237 | + # Run a dummy query just to test the db_url |
| 238 | + with maker() as session: |
| 239 | + session.execute(_SQL_SELECT) |
| 240 | + return maker |
254 | 241 |
|
| 242 | + return await hass.async_add_executor_job(_get_session_maker_for_db_url) |
255 | 243 | except SQLAlchemyError as err: |
256 | 244 | _LOGGER.error( |
257 | 245 | "Couldn't connect using %s DB_URL: %s", |
258 | 246 | redact_credentials(db_url), |
259 | 247 | redact_credentials(str(err)), |
260 | 248 | ) |
261 | | - return None |
262 | | - else: |
263 | | - return maker |
| 249 | + |
| 250 | + return None |
264 | 251 |
|
265 | 252 |
|
266 | 253 | def generate_lambda_stmt(query: str) -> StatementLambdaElement: |
|
0 commit comments