Skip to content

Commit b5f8758

Browse files
committed
Merge util._async_validate_and_get_session_maker_for_db_url
Signed-off-by: David Rapan <[email protected]>
1 parent 46e00ce commit b5f8758

File tree

3 files changed

+37
-47
lines changed

3 files changed

+37
-47
lines changed

homeassistant/components/sql/util.py

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
_LOGGER = logging.getLogger(__name__)
3232

33+
_SQL_SELECT = sqlalchemy.text("SELECT 1;")
3334
_SQL_LAMBDA_CACHE: LRUCache = LRUCache(1000)
3435

3536

@@ -107,11 +108,8 @@ async def async_create_sessionmaker(
107108
# for every sensor.
108109
elif db_url in sql_data.session_makers_by_db_url:
109110
sessmaker = sql_data.session_makers_by_db_url[db_url]
110-
elif "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url:
111-
if sessmaker := await _async_validate_and_get_session_maker_for_db_url(db_url):
112-
sql_data.session_makers_by_db_url[db_url] = sessmaker
113-
elif sessmaker := await hass.async_add_executor_job(
114-
_validate_and_get_session_maker_for_db_url, db_url
111+
elif sessmaker := await _async_validate_and_get_session_maker_for_db_url(
112+
hass, db_url
115113
):
116114
sql_data.session_makers_by_db_url[db_url] = sessmaker
117115
else:
@@ -210,57 +208,46 @@ async def _shutdown_db_engines(_: Event) -> None:
210208

211209

212210
async def _async_validate_and_get_session_maker_for_db_url(
213-
db_url: str,
214-
) -> async_scoped_session[AsyncSession] | None:
211+
hass: HomeAssistant, db_url: str
212+
) -> async_scoped_session[AsyncSession] | scoped_session[Session] | None:
215213
"""Validate the db_url and return a async session maker."""
216214
try:
217-
maker = async_scoped_session(
218-
async_sessionmaker(
219-
bind=create_async_engine(db_url, future=True), future=True
220-
),
221-
scopefunc=asyncio.current_task,
222-
)
223-
# Run a dummy query just to test the db_url
224-
async with maker() as session:
225-
await session.execute(sqlalchemy.text("SELECT 1;"))
226-
227-
except SQLAlchemyError as err:
228-
_LOGGER.error(
229-
"Couldn't connect using %s DB_URL: %s",
230-
redact_credentials(db_url),
231-
redact_credentials(str(err)),
232-
)
233-
return None
234-
else:
235-
return maker
236-
237-
238-
def _validate_and_get_session_maker_for_db_url(
239-
db_url: str,
240-
) -> scoped_session[Session] | None:
241-
"""Validate the db_url and return a session maker.
242-
243-
This does I/O and should be run in the executor.
244-
"""
245-
try:
246-
maker = scoped_session(
247-
sessionmaker(
248-
bind=sqlalchemy.create_engine(db_url, future=True), future=True
215+
if "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url:
216+
maker = async_scoped_session(
217+
async_sessionmaker(
218+
bind=create_async_engine(db_url, future=True), future=True
219+
),
220+
scopefunc=asyncio.current_task,
249221
)
250-
)
251-
# Run a dummy query just to test the db_url
252-
with maker() as session:
253-
session.execute(sqlalchemy.text("SELECT 1;"))
222+
# Run a dummy query just to test the db_url
223+
async with maker() as session:
224+
await session.execute(_SQL_SELECT)
225+
return maker
226+
227+
def _get_session_maker_for_db_url() -> scoped_session[Session] | None:
228+
"""Validate the db_url and return a session maker.
229+
230+
This does I/O and should be run in the executor.
231+
"""
232+
maker = scoped_session(
233+
sessionmaker(
234+
bind=sqlalchemy.create_engine(db_url, future=True), future=True
235+
)
236+
)
237+
# Run a dummy query just to test the db_url
238+
with maker() as session:
239+
session.execute(_SQL_SELECT)
240+
return maker
254241

242+
return await hass.async_add_executor_job(_get_session_maker_for_db_url)
255243
except SQLAlchemyError as err:
256244
_LOGGER.error(
257245
"Couldn't connect using %s DB_URL: %s",
258246
redact_credentials(db_url),
259247
redact_credentials(str(err)),
260248
)
261-
return None
262-
else:
263-
return maker
249+
250+
return None
264251

265252

266253
def generate_lambda_stmt(query: str) -> StatementLambdaElement:

tests/components/sql/test_sensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ def execute(self, query: Any) -> None:
325325
"""Execute the query."""
326326
raise SQLAlchemyError("sqlite://homeassistant:[email protected]")
327327

328+
def rollback(self) -> None:
329+
pass
330+
328331
with patch(
329332
"homeassistant.components.sql.util.scoped_session",
330333
return_value=MockSession,

tests/components/sql/test_services.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ async def test_query_service_invalid_db_url(hass: HomeAssistant) -> None:
230230

231231
with (
232232
patch(
233-
"homeassistant.components.sql.util._validate_and_get_session_maker_for_db_url",
233+
"homeassistant.components.sql.util._async_validate_and_get_session_maker_for_db_url",
234234
return_value=None,
235235
),
236236
pytest.raises(

0 commit comments

Comments
 (0)