Skip to content

Commit e06c57a

Browse files
committed
Add support for aiomysql, aiosqlite and asyncpg
Signed-off-by: David Rapan <[email protected]>
1 parent 39d76a2 commit e06c57a

File tree

9 files changed

+238
-68
lines changed

9 files changed

+238
-68
lines changed

homeassistant/components/sql/manifest.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,12 @@
66
"config_flow": true,
77
"documentation": "https://www.home-assistant.io/integrations/sql",
88
"iot_class": "local_polling",
9-
"requirements": ["SQLAlchemy==2.0.41", "sqlparse==0.5.0"]
9+
"requirements": [
10+
"aiomysql==0.3.2",
11+
"aiosqlite==0.21.0",
12+
"asyncpg==0.30.0",
13+
"greenlet==3.2.4",
14+
"SQLAlchemy==2.0.41",
15+
"sqlparse==0.5.0"
16+
]
1017
}

homeassistant/components/sql/models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from dataclasses import dataclass
66

7-
from sqlalchemy.orm import scoped_session
7+
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
8+
from sqlalchemy.orm import Session, scoped_session
89

910
from homeassistant.core import CALLBACK_TYPE
1011

@@ -14,4 +15,6 @@ class SQLData:
1415
"""Data for the sql integration."""
1516

1617
shutdown_event_cancel: CALLBACK_TYPE
17-
session_makers_by_db_url: dict[str, scoped_session]
18+
session_makers_by_db_url: dict[
19+
str, async_scoped_session[AsyncSession] | scoped_session[Session]
20+
]

homeassistant/components/sql/sensor.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from datetime import date
66
import decimal
77
import logging
8-
from typing import Any
8+
from typing import TYPE_CHECKING, Any
99

1010
from sqlalchemy.engine import Result
1111
from sqlalchemy.exc import SQLAlchemyError
12-
from sqlalchemy.orm import scoped_session
12+
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
13+
from sqlalchemy.orm import Session, scoped_session
1314

1415
from homeassistant.components.recorder import CONF_DB_URL, get_instance
1516
from homeassistant.components.sensor import CONF_STATE_CLASS
@@ -200,7 +201,7 @@ class SQLSensor(ManualTriggerSensorEntity):
200201
def __init__(
201202
self,
202203
trigger_entity_config: ConfigType,
203-
sessmaker: scoped_session,
204+
sessmaker: async_scoped_session[AsyncSession] | scoped_session[Session],
204205
query: str,
205206
column: str,
206207
value_template: ValueTemplate | None,
@@ -243,31 +244,10 @@ def extra_state_attributes(self) -> dict[str, Any] | None:
243244
"""Return extra attributes."""
244245
return dict(self._attr_extra_state_attributes)
245246

246-
async def async_update(self) -> None:
247-
"""Retrieve sensor data from the query using the right executor."""
248-
if self._use_database_executor:
249-
await get_instance(self.hass).async_add_executor_job(self._update)
250-
else:
251-
await self.hass.async_add_executor_job(self._update)
252-
253-
def _update(self) -> None:
254-
"""Retrieve sensor data from the query."""
247+
def _process(self, result: Result) -> None:
248+
"""Process the SQL result."""
255249
data = None
256250
extra_state_attributes = {}
257-
self._attr_extra_state_attributes = {}
258-
sess: scoped_session = self.sessionmaker()
259-
try:
260-
result: Result = sess.execute(self._lambda_stmt)
261-
except SQLAlchemyError as err:
262-
_LOGGER.error(
263-
"Error executing query %s: %s",
264-
self._query,
265-
redact_credentials(str(err)),
266-
)
267-
sess.rollback()
268-
sess.close()
269-
return
270-
271251
for res in result.mappings():
272252
_LOGGER.debug("Query %s result in %s", self._query, res.items())
273253
data = res[self._column_name]
@@ -298,4 +278,35 @@ def _update(self) -> None:
298278
if data is None:
299279
_LOGGER.warning("%s returned no results", self._query)
300280

301-
sess.close()
281+
def _update(self) -> None:
282+
"""Retrieve sensor data from the query."""
283+
self._attr_extra_state_attributes = {}
284+
try:
285+
if TYPE_CHECKING:
286+
assert isinstance(self.sessionmaker, scoped_session)
287+
with self.sessionmaker() as session:
288+
self._process(session.execute(self._lambda_stmt))
289+
except SQLAlchemyError as err:
290+
_LOGGER.error(
291+
"Error executing query %s: %s",
292+
self._query,
293+
redact_credentials(str(err)),
294+
)
295+
296+
async def async_update(self) -> None:
297+
"""Retrieve sensor data from the query using the right executor."""
298+
if isinstance(self.sessionmaker, async_scoped_session):
299+
self._attr_extra_state_attributes = {}
300+
try:
301+
async with self.sessionmaker() as session:
302+
self._process(await session.execute(self._lambda_stmt))
303+
except SQLAlchemyError as err:
304+
_LOGGER.error(
305+
"Error executing query %s: %s",
306+
self._query,
307+
redact_credentials(str(err)),
308+
)
309+
elif self._use_database_executor:
310+
await get_instance(self.hass).async_add_executor_job(self._update)
311+
else:
312+
await self.hass.async_add_executor_job(self._update)

homeassistant/components/sql/services.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import datetime
66
import decimal
77
import logging
8+
from typing import TYPE_CHECKING
89

910
from sqlalchemy.engine import Result
1011
from sqlalchemy.exc import SQLAlchemyError
11-
from sqlalchemy.orm import Session
12+
from sqlalchemy.ext.asyncio import async_scoped_session
13+
from sqlalchemy.orm import scoped_session
1214
import voluptuous as vol
1315

1416
from homeassistant.components.recorder import CONF_DB_URL, get_instance
@@ -70,39 +72,52 @@ async def _async_query_service(
7072
translation_placeholders={"db_url": redact_credentials(db_url)},
7173
)
7274

75+
def _process(result: Result) -> list[JsonValueType]:
76+
rows: list[JsonValueType] = []
77+
for row in result.mappings():
78+
processed_row: dict[str, JsonValueType] = {}
79+
for key, value in row.items():
80+
if isinstance(value, decimal.Decimal):
81+
processed_row[key] = float(value)
82+
elif isinstance(value, datetime.date):
83+
processed_row[key] = value.isoformat()
84+
elif isinstance(value, (bytes, bytearray)):
85+
processed_row[key] = f"0x{value.hex()}"
86+
else:
87+
processed_row[key] = value
88+
rows.append(processed_row)
89+
return rows
90+
7391
def _execute_and_convert_query() -> list[JsonValueType]:
7492
"""Execute the query and return the results with converted types."""
75-
sess: Session = sessmaker()
7693
try:
77-
result: Result = sess.execute(generate_lambda_stmt(query_str))
94+
if TYPE_CHECKING:
95+
assert isinstance(sessmaker, scoped_session)
96+
with sessmaker() as session:
97+
return _process(session.execute(generate_lambda_stmt(query_str)))
7898
except SQLAlchemyError as err:
7999
_LOGGER.debug(
80100
"Error executing query %s: %s",
81101
query_str,
82102
redact_credentials(str(err)),
83103
)
84-
sess.rollback()
85104
raise
86-
else:
87-
rows: list[JsonValueType] = []
88-
for row in result.mappings():
89-
processed_row: dict[str, JsonValueType] = {}
90-
for key, value in row.items():
91-
if isinstance(value, decimal.Decimal):
92-
processed_row[key] = float(value)
93-
elif isinstance(value, datetime.date):
94-
processed_row[key] = value.isoformat()
95-
elif isinstance(value, (bytes, bytearray)):
96-
processed_row[key] = f"0x{value.hex()}"
97-
else:
98-
processed_row[key] = value
99-
rows.append(processed_row)
100-
return rows
101-
finally:
102-
sess.close()
103105

104106
try:
105-
if use_database_executor:
107+
if isinstance(sessmaker, async_scoped_session):
108+
try:
109+
async with sessmaker() as session:
110+
result = _process(
111+
await session.execute(generate_lambda_stmt(query_str))
112+
)
113+
except SQLAlchemyError as err:
114+
_LOGGER.debug(
115+
"Error executing query %s: %s",
116+
query_str,
117+
redact_credentials(str(err)),
118+
)
119+
raise
120+
elif use_database_executor:
106121
result = await get_instance(call.hass).async_add_executor_job(
107122
_execute_and_convert_query
108123
)

homeassistant/components/sql/util.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import logging
67

78
import sqlalchemy
89
from sqlalchemy import lambda_stmt
910
from sqlalchemy.exc import SQLAlchemyError
11+
from sqlalchemy.ext.asyncio import (
12+
AsyncSession,
13+
async_scoped_session,
14+
async_sessionmaker,
15+
create_async_engine,
16+
)
1017
from sqlalchemy.orm import Session, scoped_session, sessionmaker
1118
from sqlalchemy.sql.lambdas import StatementLambdaElement
1219
from sqlalchemy.util import LRUCache
@@ -55,7 +62,9 @@ def validate_sql_select(value: str) -> str:
5562

5663
async def async_create_sessionmaker(
5764
hass: HomeAssistant, db_url: str
58-
) -> tuple[scoped_session | None, bool, bool]:
65+
) -> tuple[
66+
async_scoped_session[AsyncSession] | scoped_session[Session] | None, bool, bool
67+
]:
5968
"""Create a session maker for the given db_url.
6069
6170
This function gets or creates a SQLAlchemy `scoped_session` for the given
@@ -83,7 +92,7 @@ async def async_create_sessionmaker(
8392
uses_recorder_db = False
8493
else:
8594
uses_recorder_db = db_url == instance.db_url
86-
sessmaker: scoped_session | None
95+
sessmaker: async_scoped_session[AsyncSession] | scoped_session[Session] | None
8796
sql_data = _async_get_or_init_domain_data(hass)
8897
use_database_executor = False
8998
if uses_recorder_db and instance.dialect_name == SupportedDialect.SQLITE:
@@ -98,6 +107,9 @@ async def async_create_sessionmaker(
98107
# for every sensor.
99108
elif db_url in sql_data.session_makers_by_db_url:
100109
sessmaker = sql_data.session_makers_by_db_url[db_url]
110+
elif "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url:
111+
if sessmaker := await _async_validate_and_get_session_maker_for_db_url(db_url):
112+
sql_data.session_makers_by_db_url[db_url] = sessmaker
101113
elif sessmaker := await hass.async_add_executor_job(
102114
_validate_and_get_session_maker_for_db_url, db_url
103115
):
@@ -169,7 +181,9 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData:
169181
sql_data: SQLData = hass.data[DOMAIN]
170182
return sql_data
171183

172-
session_makers_by_db_url: dict[str, scoped_session] = {}
184+
session_makers_by_db_url: dict[
185+
str, async_scoped_session[AsyncSession] | scoped_session[Session]
186+
] = {}
173187

174188
#
175189
# Ensure we dispose of all engines at shutdown
@@ -178,10 +192,13 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData:
178192
# Shutdown all sessions in the executor since they will
179193
# do blocking I/O
180194
#
181-
def _shutdown_db_engines(event: Event) -> None:
195+
async def _shutdown_db_engines(_: Event) -> None:
182196
"""Shutdown all database engines."""
183197
for sessmaker in session_makers_by_db_url.values():
184-
sessmaker.connection().engine.dispose()
198+
if isinstance(sessmaker, async_scoped_session):
199+
await (await sessmaker.connection()).engine.dispose()
200+
else:
201+
sessmaker.connection().engine.dispose()
185202

186203
cancel_shutdown = hass.bus.async_listen_once(
187204
EVENT_HOMEASSISTANT_STOP, _shutdown_db_engines
@@ -192,18 +209,46 @@ def _shutdown_db_engines(event: Event) -> None:
192209
return sql_data
193210

194211

195-
def _validate_and_get_session_maker_for_db_url(db_url: str) -> scoped_session | None:
212+
async def _async_validate_and_get_session_maker_for_db_url(
213+
db_url: str,
214+
) -> async_scoped_session[AsyncSession] | None:
215+
"""Validate the db_url and return a async session maker."""
216+
try:
217+
maker = async_scoped_session(
218+
async_sessionmaker(bind=create_async_engine(db_url)),
219+
scopefunc=asyncio.current_task,
220+
)
221+
# Run a dummy query just to test the db_url
222+
async with maker() as session:
223+
await session.execute(sqlalchemy.text("SELECT 1;"))
224+
225+
except SQLAlchemyError as err:
226+
_LOGGER.error(
227+
"Couldn't connect using %s DB_URL: %s",
228+
redact_credentials(db_url),
229+
redact_credentials(str(err)),
230+
)
231+
return None
232+
else:
233+
return maker
234+
235+
236+
def _validate_and_get_session_maker_for_db_url(
237+
db_url: str,
238+
) -> scoped_session[Session] | None:
196239
"""Validate the db_url and return a session maker.
197240
198241
This does I/O and should be run in the executor.
199242
"""
200-
sess: Session | None = None
201243
try:
202-
engine = sqlalchemy.create_engine(db_url, future=True)
203-
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
244+
maker = scoped_session(
245+
sessionmaker(
246+
bind=sqlalchemy.create_engine(db_url, future=True), future=True
247+
)
248+
)
204249
# Run a dummy query just to test the db_url
205-
sess = sessmaker()
206-
sess.execute(sqlalchemy.text("SELECT 1;"))
250+
with maker() as session:
251+
session.execute(sqlalchemy.text("SELECT 1;"))
207252

208253
except SQLAlchemyError as err:
209254
_LOGGER.error(
@@ -213,10 +258,7 @@ def _validate_and_get_session_maker_for_db_url(db_url: str) -> scoped_session |
213258
)
214259
return None
215260
else:
216-
return sessmaker
217-
finally:
218-
if sess:
219-
sess.close()
261+
return maker
220262

221263

222264
def generate_lambda_stmt(query: str) -> StatementLambdaElement:

requirements_all.txt

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)