Skip to content

Commit 900822f

Browse files
committed
Merge util._async_validate_and_get_session_maker_for_db_url
Signed-off-by: David Rapan <[email protected]>
1 parent 46e00ce commit 900822f

File tree

5 files changed

+41
-51
lines changed

5 files changed

+41
-51
lines changed

homeassistant/components/sql/sensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,12 @@ def _update(self) -> None:
289289
try:
290290
self._process(session.execute(self._lambda_stmt))
291291
except SQLAlchemyError as err:
292+
session.rollback()
292293
_LOGGER.error(
293294
"Error executing query %s: %s",
294295
self._query,
295296
redact_credentials(str(err)),
296297
)
297-
session.rollback()
298298

299299
async def async_update(self) -> None:
300300
"""Retrieve sensor data from the query using the right executor."""
@@ -304,12 +304,12 @@ async def async_update(self) -> None:
304304
try:
305305
self._process(await session.execute(self._lambda_stmt))
306306
except SQLAlchemyError as err:
307+
await session.rollback()
307308
_LOGGER.error(
308309
"Error executing query %s: %s",
309310
self._query,
310311
redact_credentials(str(err)),
311312
)
312-
await session.rollback()
313313
elif self._use_database_executor:
314314
await get_instance(self.hass).async_add_executor_job(self._update)
315315
else:

homeassistant/components/sql/services.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ def _execute_and_convert_query() -> list[JsonValueType]:
9696
try:
9797
return _process(session.execute(generate_lambda_stmt(query_str)))
9898
except SQLAlchemyError as err:
99+
session.rollback()
99100
_LOGGER.debug(
100101
"Error executing query %s: %s",
101102
query_str,
102103
redact_credentials(str(err)),
103104
)
104-
session.rollback()
105105
raise
106106

107107
try:
@@ -112,12 +112,12 @@ def _execute_and_convert_query() -> list[JsonValueType]:
112112
await session.execute(generate_lambda_stmt(query_str))
113113
)
114114
except SQLAlchemyError as err:
115+
await session.rollback()
115116
_LOGGER.debug(
116117
"Error executing query %s: %s",
117118
query_str,
118119
redact_credentials(str(err)),
119120
)
120-
await session.rollback()
121121
raise
122122
elif use_database_executor:
123123
result = await get_instance(call.hass).async_add_executor_job(

homeassistant/components/sql/util.py

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
_LOGGER = logging.getLogger(__name__)
3232

33+
_SQL_SELECT = sqlalchemy.text("SELECT 1;")
3334
_SQL_LAMBDA_CACHE: LRUCache = LRUCache(1000)
3435

3536

@@ -107,11 +108,8 @@ async def async_create_sessionmaker(
107108
# for every sensor.
108109
elif db_url in sql_data.session_makers_by_db_url:
109110
sessmaker = sql_data.session_makers_by_db_url[db_url]
110-
elif "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url:
111-
if sessmaker := await _async_validate_and_get_session_maker_for_db_url(db_url):
112-
sql_data.session_makers_by_db_url[db_url] = sessmaker
113-
elif sessmaker := await hass.async_add_executor_job(
114-
_validate_and_get_session_maker_for_db_url, db_url
111+
elif sessmaker := await _async_validate_and_get_session_maker_for_db_url(
112+
hass, db_url
115113
):
116114
sql_data.session_makers_by_db_url[db_url] = sessmaker
117115
else:
@@ -210,57 +208,46 @@ async def _shutdown_db_engines(_: Event) -> None:
210208

211209

212210
async def _async_validate_and_get_session_maker_for_db_url(
213-
db_url: str,
214-
) -> async_scoped_session[AsyncSession] | None:
211+
hass: HomeAssistant, db_url: str
212+
) -> async_scoped_session[AsyncSession] | scoped_session[Session] | None:
215213
"""Validate the db_url and return a async session maker."""
216214
try:
217-
maker = async_scoped_session(
218-
async_sessionmaker(
219-
bind=create_async_engine(db_url, future=True), future=True
220-
),
221-
scopefunc=asyncio.current_task,
222-
)
223-
# Run a dummy query just to test the db_url
224-
async with maker() as session:
225-
await session.execute(sqlalchemy.text("SELECT 1;"))
226-
227-
except SQLAlchemyError as err:
228-
_LOGGER.error(
229-
"Couldn't connect using %s DB_URL: %s",
230-
redact_credentials(db_url),
231-
redact_credentials(str(err)),
232-
)
233-
return None
234-
else:
235-
return maker
236-
237-
238-
def _validate_and_get_session_maker_for_db_url(
239-
db_url: str,
240-
) -> scoped_session[Session] | None:
241-
"""Validate the db_url and return a session maker.
242-
243-
This does I/O and should be run in the executor.
244-
"""
245-
try:
246-
maker = scoped_session(
247-
sessionmaker(
248-
bind=sqlalchemy.create_engine(db_url, future=True), future=True
215+
if "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url:
216+
maker = async_scoped_session(
217+
async_sessionmaker(
218+
bind=create_async_engine(db_url, future=True), future=True
219+
),
220+
scopefunc=asyncio.current_task,
249221
)
250-
)
251-
# Run a dummy query just to test the db_url
252-
with maker() as session:
253-
session.execute(sqlalchemy.text("SELECT 1;"))
222+
# Run a dummy query just to test the db_url
223+
async with maker() as session:
224+
await session.execute(_SQL_SELECT)
225+
return maker
226+
227+
def _get_session_maker_for_db_url() -> scoped_session[Session] | None:
228+
"""Validate the db_url and return a session maker.
229+
230+
This does I/O and should be run in the executor.
231+
"""
232+
maker = scoped_session(
233+
sessionmaker(
234+
bind=sqlalchemy.create_engine(db_url, future=True), future=True
235+
)
236+
)
237+
# Run a dummy query just to test the db_url
238+
with maker() as session:
239+
session.execute(_SQL_SELECT)
240+
return maker
254241

242+
return await hass.async_add_executor_job(_get_session_maker_for_db_url)
255243
except SQLAlchemyError as err:
256244
_LOGGER.error(
257245
"Couldn't connect using %s DB_URL: %s",
258246
redact_credentials(db_url),
259247
redact_credentials(str(err)),
260248
)
261-
return None
262-
else:
263-
return maker
249+
250+
return None
264251

265252

266253
def generate_lambda_stmt(query: str) -> StatementLambdaElement:

tests/components/sql/test_sensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ def execute(self, query: Any) -> None:
325325
"""Execute the query."""
326326
raise SQLAlchemyError("sqlite://homeassistant:[email protected]")
327327

328+
def rollback(self) -> None:
329+
pass
330+
328331
with patch(
329332
"homeassistant.components.sql.util.scoped_session",
330333
return_value=MockSession,

tests/components/sql/test_services.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ async def test_query_service_invalid_db_url(hass: HomeAssistant) -> None:
230230

231231
with (
232232
patch(
233-
"homeassistant.components.sql.util._validate_and_get_session_maker_for_db_url",
233+
"homeassistant.components.sql.util._async_validate_and_get_session_maker_for_db_url",
234234
return_value=None,
235235
),
236236
pytest.raises(

0 commit comments

Comments
 (0)