diff --git a/homeassistant/components/sql/manifest.json b/homeassistant/components/sql/manifest.json index 244334565657ea..9517bd2f3e4879 100644 --- a/homeassistant/components/sql/manifest.json +++ b/homeassistant/components/sql/manifest.json @@ -6,5 +6,11 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/sql", "iot_class": "local_polling", - "requirements": ["SQLAlchemy==2.0.41", "sqlparse==0.5.0"] + "requirements": [ + "aiomysql==0.3.2", + "aiosqlite==0.21.0", + "asyncpg==0.30.0", + "SQLAlchemy[asyncio]==2.0.41", + "sqlparse==0.5.0" + ] } diff --git a/homeassistant/components/sql/models.py b/homeassistant/components/sql/models.py index 872ceedde71b9a..45ad238573f0a5 100644 --- a/homeassistant/components/sql/models.py +++ b/homeassistant/components/sql/models.py @@ -4,7 +4,8 @@ from dataclasses import dataclass -from sqlalchemy.orm import scoped_session +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session +from sqlalchemy.orm import Session, scoped_session from homeassistant.core import CALLBACK_TYPE @@ -14,4 +15,6 @@ class SQLData: """Data for the sql integration.""" shutdown_event_cancel: CALLBACK_TYPE - session_makers_by_db_url: dict[str, scoped_session] + session_makers_by_db_url: dict[ + str, async_scoped_session[AsyncSession] | scoped_session[Session] + ] diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 508365b5c0dce0..42d311c985c762 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -2,14 +2,13 @@ from __future__ import annotations -from datetime import date -import decimal import logging -from typing import Any +from typing import TYPE_CHECKING, Any from sqlalchemy.engine import Result from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import scoped_session +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session +from sqlalchemy.orm import Session, scoped_session from homeassistant.components.recorder import CONF_DB_URL, get_instance from homeassistant.components.sensor import CONF_STATE_CLASS @@ -43,6 +42,7 @@ from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN from .util import ( async_create_sessionmaker, + ensure_serializable, generate_lambda_stmt, redact_credentials, resolve_db_url, @@ -200,7 +200,7 @@ class SQLSensor(ManualTriggerSensorEntity): def __init__( self, trigger_entity_config: ConfigType, - sessmaker: scoped_session, + sessmaker: async_scoped_session[AsyncSession] | scoped_session[Session], query: str, column: str, value_template: ValueTemplate | None, @@ -243,43 +243,16 @@ def extra_state_attributes(self) -> dict[str, Any] | None: """Return extra attributes.""" return dict(self._attr_extra_state_attributes) - async def async_update(self) -> None: - """Retrieve sensor data from the query using the right executor.""" - if self._use_database_executor: - await get_instance(self.hass).async_add_executor_job(self._update) - else: - await self.hass.async_add_executor_job(self._update) - - def _update(self) -> None: - """Retrieve sensor data from the query.""" + def _process(self, result: Result) -> None: + """Process the SQL result.""" data = None - extra_state_attributes = {} - self._attr_extra_state_attributes = {} - sess: scoped_session = self.sessionmaker() - try: - result: Result = sess.execute(self._lambda_stmt) - except SQLAlchemyError as err: - _LOGGER.error( - "Error executing query %s: %s", - self._query, - redact_credentials(str(err)), - ) - sess.rollback() - sess.close() - return - - for res in result.mappings(): - _LOGGER.debug("Query %s result in %s", self._query, res.items()) - data = res[self._column_name] - for key, value in res.items(): - if isinstance(value, decimal.Decimal): - value = float(value) - elif isinstance(value, date): - value = value.isoformat() - elif isinstance(value, (bytes, bytearray)): - value = f"0x{value.hex()}" - extra_state_attributes[key] = value - self._attr_extra_state_attributes[key] = value + + for row in result.mappings(): + row_items = row.items() + _LOGGER.debug("Query %s result in %s", self._query, row_items) + data = row[self._column_name] + for key, value in row_items: + self._attr_extra_state_attributes[key] = ensure_serializable(value) if data is not None and isinstance(data, (bytes, bytearray)): data = f"0x{data.hex()}" @@ -298,4 +271,39 @@ def _update(self) -> None: if data is None: _LOGGER.warning("%s returned no results", self._query) - sess.close() + def _update(self) -> None: + """Retrieve sensor data from the query. + + This does I/O and should be run in the executor. + """ + if TYPE_CHECKING: + assert isinstance(self.sessionmaker, scoped_session) + with self.sessionmaker() as session: + try: + self._process(session.execute(self._lambda_stmt)) + except SQLAlchemyError as err: + _LOGGER.error( + "Error executing query %s: %s", + self._query, + redact_credentials(str(err)), + ) + session.rollback() + + async def async_update(self) -> None: + """Retrieve sensor data from the query using the right executor.""" + self._attr_extra_state_attributes = {} + if isinstance(self.sessionmaker, async_scoped_session): + async with self.sessionmaker() as session: + try: + self._process(await session.execute(self._lambda_stmt)) + except SQLAlchemyError as err: + _LOGGER.error( + "Error executing query %s: %s", + self._query, + redact_credentials(str(err)), + ) + await session.rollback() + elif self._use_database_executor: + await get_instance(self.hass).async_add_executor_job(self._update) + else: + await self.hass.async_add_executor_job(self._update) diff --git a/homeassistant/components/sql/services.py b/homeassistant/components/sql/services.py index c7b74bd82b6ed4..9fd6dd8f669484 100644 --- a/homeassistant/components/sql/services.py +++ b/homeassistant/components/sql/services.py @@ -2,13 +2,13 @@ from __future__ import annotations -import datetime -import decimal import logging +from typing import TYPE_CHECKING from sqlalchemy.engine import Result from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.orm import scoped_session import voluptuous as vol from homeassistant.components.recorder import CONF_DB_URL, get_instance @@ -26,6 +26,7 @@ from .const import CONF_QUERY, DOMAIN from .util import ( async_create_sessionmaker, + ensure_serializable, generate_lambda_stmt, redact_credentials, resolve_db_url, @@ -70,39 +71,47 @@ async def _async_query_service( translation_placeholders={"db_url": redact_credentials(db_url)}, ) + def _process(result: Result) -> list[JsonValueType]: + rows: list[JsonValueType] = [] + for row in result.mappings(): + processed_row: dict[str, JsonValueType] = {} + for key, value in row.items(): + processed_row[key] = ensure_serializable(value) + rows.append(processed_row) + return rows + def _execute_and_convert_query() -> list[JsonValueType]: """Execute the query and return the results with converted types.""" - sess: Session = sessmaker() - try: - result: Result = sess.execute(generate_lambda_stmt(query_str)) - except SQLAlchemyError as err: - _LOGGER.debug( - "Error executing query %s: %s", - query_str, - redact_credentials(str(err)), - ) - sess.rollback() - raise - else: - rows: list[JsonValueType] = [] - for row in result.mappings(): - processed_row: dict[str, JsonValueType] = {} - for key, value in row.items(): - if isinstance(value, decimal.Decimal): - processed_row[key] = float(value) - elif isinstance(value, datetime.date): - processed_row[key] = value.isoformat() - elif isinstance(value, (bytes, bytearray)): - processed_row[key] = f"0x{value.hex()}" - else: - processed_row[key] = value - rows.append(processed_row) - return rows - finally: - sess.close() + if TYPE_CHECKING: + assert isinstance(sessmaker, scoped_session) + with sessmaker() as session: + try: + return _process(session.execute(generate_lambda_stmt(query_str))) + except SQLAlchemyError as err: + _LOGGER.debug( + "Error executing query %s: %s", + query_str, + redact_credentials(str(err)), + ) + session.rollback() + raise try: - if use_database_executor: + if isinstance(sessmaker, async_scoped_session): + async with sessmaker() as session: + try: + result = _process( + await session.execute(generate_lambda_stmt(query_str)) + ) + except SQLAlchemyError as err: + _LOGGER.debug( + "Error executing query %s: %s", + query_str, + redact_credentials(str(err)), + ) + await session.rollback() + raise + elif use_database_executor: result = await get_instance(call.hass).async_add_executor_job( _execute_and_convert_query ) diff --git a/homeassistant/components/sql/util.py b/homeassistant/components/sql/util.py index 0200a83c9e8499..f99c7aae4a83bd 100644 --- a/homeassistant/components/sql/util.py +++ b/homeassistant/components/sql/util.py @@ -2,11 +2,21 @@ from __future__ import annotations +import asyncio +from datetime import date +from decimal import Decimal import logging +from typing import Any import sqlalchemy from sqlalchemy import lambda_stmt from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_scoped_session, + async_sessionmaker, + create_async_engine, +) from sqlalchemy.orm import Session, scoped_session, sessionmaker from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.util import LRUCache @@ -23,6 +33,7 @@ _LOGGER = logging.getLogger(__name__) +_SQL_SELECT = sqlalchemy.text("SELECT 1;") _SQL_LAMBDA_CACHE: LRUCache = LRUCache(1000) @@ -55,7 +66,9 @@ def validate_sql_select(value: str) -> str: async def async_create_sessionmaker( hass: HomeAssistant, db_url: str -) -> tuple[scoped_session | None, bool, bool]: +) -> tuple[ + async_scoped_session[AsyncSession] | scoped_session[Session] | None, bool, bool +]: """Create a session maker for the given db_url. This function gets or creates a SQLAlchemy `scoped_session` for the given @@ -83,7 +96,7 @@ async def async_create_sessionmaker( uses_recorder_db = False else: uses_recorder_db = db_url == instance.db_url - sessmaker: scoped_session | None + sessmaker: async_scoped_session[AsyncSession] | scoped_session[Session] | None sql_data = _async_get_or_init_domain_data(hass) use_database_executor = False if uses_recorder_db and instance.dialect_name == SupportedDialect.SQLITE: @@ -98,8 +111,8 @@ async def async_create_sessionmaker( # for every sensor. elif db_url in sql_data.session_makers_by_db_url: sessmaker = sql_data.session_makers_by_db_url[db_url] - elif sessmaker := await hass.async_add_executor_job( - _validate_and_get_session_maker_for_db_url, db_url + elif sessmaker := await _async_validate_and_get_session_maker_for_db_url( + hass, db_url ): sql_data.session_makers_by_db_url[db_url] = sessmaker else: @@ -169,7 +182,9 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData: sql_data: SQLData = hass.data[DOMAIN] return sql_data - session_makers_by_db_url: dict[str, scoped_session] = {} + session_makers_by_db_url: dict[ + str, async_scoped_session[AsyncSession] | scoped_session[Session] + ] = {} # # Ensure we dispose of all engines at shutdown @@ -178,9 +193,14 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData: # Shutdown all sessions in the executor since they will # do blocking I/O # - def _shutdown_db_engines(event: Event) -> None: + async def _shutdown_db_engines(_: Event) -> None: """Shutdown all database engines.""" for sessmaker in session_makers_by_db_url.values(): + if isinstance(sessmaker, async_scoped_session): + _LOGGER.error("Disposed async engine for shutdown 1") + await (await sessmaker.connection()).engine.dispose() + _LOGGER.error("Disposed async engine for shutdown 2") + raise SQLAlchemyError("Disposed async engine for shutdown") sessmaker.connection().engine.dispose() cancel_shutdown = hass.bus.async_listen_once( @@ -192,34 +212,63 @@ def _shutdown_db_engines(event: Event) -> None: return sql_data -def _validate_and_get_session_maker_for_db_url(db_url: str) -> scoped_session | None: - """Validate the db_url and return a session maker. - - This does I/O and should be run in the executor. - """ - sess: Session | None = None +async def _async_validate_and_get_session_maker_for_db_url( + hass: HomeAssistant, db_url: str +) -> async_scoped_session[AsyncSession] | scoped_session[Session] | None: + """Validate the db_url and return a session maker.""" try: - engine = sqlalchemy.create_engine(db_url, future=True) - sessmaker = scoped_session(sessionmaker(bind=engine, future=True)) - # Run a dummy query just to test the db_url - sess = sessmaker() - sess.execute(sqlalchemy.text("SELECT 1;")) - + if "+aiomysql" in db_url or "+aiosqlite" in db_url or "+asyncpg" in db_url: + maker = async_scoped_session( + async_sessionmaker( + bind=create_async_engine(db_url, future=True), future=True + ), + scopefunc=asyncio.current_task, + ) + # Run a dummy query just to test the db_url + async with maker() as session: + await session.execute(_SQL_SELECT) + return maker + + def _get_session_maker_for_db_url() -> scoped_session[Session] | None: + """Validate the db_url and return a session maker. + + This does I/O and should be run in the executor. + """ + maker = scoped_session( + sessionmaker( + bind=sqlalchemy.create_engine(db_url, future=True), future=True + ) + ) + # Run a dummy query just to test the db_url + with maker() as session: + session.execute(_SQL_SELECT) + return maker + + return await hass.async_add_executor_job(_get_session_maker_for_db_url) except SQLAlchemyError as err: _LOGGER.error( "Couldn't connect using %s DB_URL: %s", redact_credentials(db_url), redact_credentials(str(err)), ) - return None - else: - return sessmaker - finally: - if sess: - sess.close() + + return None def generate_lambda_stmt(query: str) -> StatementLambdaElement: """Generate the lambda statement.""" text = sqlalchemy.text(query) return lambda_stmt(lambda: text, lambda_cache=_SQL_LAMBDA_CACHE) + + +def ensure_serializable(value: Any) -> Any: + """Ensure value is serializable.""" + match value: + case Decimal(): + return float(value) + case date(): + return value.isoformat() + case bytes() | bytearray(): + return f"0x{value.hex()}" + case _: + return value diff --git a/requirements_all.txt b/requirements_all.txt index f558d76a0e5d52..2bd65b88ce3147 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -118,9 +118,11 @@ RestrictedPython==8.1 RtmAPI==0.7.2 # homeassistant.components.recorder -# homeassistant.components.sql SQLAlchemy==2.0.41 +# homeassistant.components.sql +SQLAlchemy[asyncio]==2.0.41 + # homeassistant.components.tami4 Tami4EdgeAPI==3.0 @@ -327,6 +329,9 @@ aiomodernforms==0.1.8 # homeassistant.components.yamaha_musiccast aiomusiccast==0.14.8 +# homeassistant.components.sql +aiomysql==0.3.2 + # homeassistant.components.nanoleaf aionanoleaf==0.2.1 @@ -404,6 +409,9 @@ aioslimproto==3.0.0 # homeassistant.components.solaredge aiosolaredge==0.2.0 +# homeassistant.components.sql +aiosqlite==0.21.0 + # homeassistant.components.steamist aiosteamist==1.0.1 @@ -553,6 +561,9 @@ asyncarve==0.1.1 # homeassistant.components.keyboard_remote asyncinotify==4.2.0 +# homeassistant.components.sql +asyncpg==0.30.0 + # homeassistant.components.supla asyncpysupla==0.0.5 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index a0d00fe70361e0..3f010afda634ed 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -112,9 +112,11 @@ RestrictedPython==8.1 RtmAPI==0.7.2 # homeassistant.components.recorder -# homeassistant.components.sql SQLAlchemy==2.0.41 +# homeassistant.components.sql +SQLAlchemy[asyncio]==2.0.41 + # homeassistant.components.tami4 Tami4EdgeAPI==3.0 @@ -309,6 +311,9 @@ aiomodernforms==0.1.8 # homeassistant.components.yamaha_musiccast aiomusiccast==0.14.8 +# homeassistant.components.sql +aiomysql==0.3.2 + # homeassistant.components.nanoleaf aionanoleaf==0.2.1 @@ -386,6 +391,9 @@ aioslimproto==3.0.0 # homeassistant.components.solaredge aiosolaredge==0.2.0 +# homeassistant.components.sql +aiosqlite==0.21.0 + # homeassistant.components.steamist aiosteamist==1.0.1 @@ -514,6 +522,9 @@ async-upnp-client==0.45.0 # homeassistant.components.arve asyncarve==0.1.1 +# homeassistant.components.sql +asyncpg==0.30.0 + # homeassistant.components.sleepiq asyncsleepiq==1.6.0 diff --git a/script/hassfest/requirements.py b/script/hassfest/requirements.py index f1048b866e22bf..6be1162049ca19 100644 --- a/script/hassfest/requirements.py +++ b/script/hassfest/requirements.py @@ -324,6 +324,8 @@ }, # https://github.com/smappee/pysmappee "smappee": {"homeassistant": {"pysmappee"}}, + # https://github.com/aio-libs/aiomysql + "sql": {"homeassistant": {"aiomysql"}}, # https://github.com/watergate-ai/watergate-local-api-python "watergate": {"homeassistant": {"watergate-local-api"}}, # https://github.com/markusressel/xs1-api-client diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index 73879065999f63..d66ac28f6b12f8 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -2,10 +2,12 @@ from __future__ import annotations +import copy from datetime import timedelta from pathlib import Path import sqlite3 -from typing import Any +import types +from typing import Any, Self from unittest.mock import patch from freezegun.api import FrozenDateTimeFactory @@ -189,14 +191,28 @@ def make_test_db(): @pytest.mark.parametrize( - ("url", "expected_patterns", "not_expected_patterns"), + ("patch_create", "url", "expected_patterns", "not_expected_patterns"), [ ( + "homeassistant.components.sql.util.create_async_engine", + "sqlite+aiosqlite://homeassistant:hunter2@homeassistant.local", + ["sqlite+aiosqlite://****:****@homeassistant.local"], + ["sqlite+aiosqlite://homeassistant:hunter2@homeassistant.local"], + ), + ( + "homeassistant.components.sql.util.create_async_engine", + "sqlite+aiosqlite://homeassistant.local", + ["sqlite+aiosqlite://homeassistant.local"], + [], + ), + ( + "homeassistant.components.sql.util.sqlalchemy.create_engine", "sqlite://homeassistant:hunter2@homeassistant.local", ["sqlite://****:****@homeassistant.local"], ["sqlite://homeassistant:hunter2@homeassistant.local"], ), ( + "homeassistant.components.sql.util.sqlalchemy.create_engine", "sqlite://homeassistant.local", ["sqlite://homeassistant.local"], [], @@ -207,6 +223,7 @@ async def test_invalid_url_setup( recorder_mock: Recorder, hass: HomeAssistant, caplog: pytest.LogCaptureFixture, + patch_create: str, url: str, expected_patterns: str, not_expected_patterns: str, @@ -228,10 +245,7 @@ async def test_invalid_url_setup( entry.add_to_hass(hass) - with patch( - "homeassistant.components.sql.util.sqlalchemy.create_engine", - side_effect=SQLAlchemyError(url), - ): + with patch(patch_create, side_effect=SQLAlchemyError(url)): await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() @@ -255,10 +269,24 @@ async def test_invalid_url_on_update( class MockSession: """Mock session.""" + def __enter__(self) -> Self: + return self + + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + pass + def execute(self, query: Any) -> None: """Execute the query.""" raise SQLAlchemyError("sqlite://homeassistant:hunter2@homeassistant.local") + def rollback(self) -> None: + pass + with patch( "homeassistant.components.sql.util.scoped_session", return_value=MockSession, @@ -273,10 +301,19 @@ def execute(self, query: Any) -> None: assert "sqlite://****:****@homeassistant.local" in caplog.text -async def test_query_from_yaml(recorder_mock: Recorder, hass: HomeAssistant) -> None: +@pytest.mark.parametrize("async_driver", [True, False]) +async def test_query_from_yaml( + recorder_mock: Recorder, hass: HomeAssistant, async_driver: bool +) -> None: """Test the SQL sensor from yaml config.""" - assert await async_setup_component(hass, DOMAIN, YAML_CONFIG) + config = YAML_CONFIG + + if async_driver: + config = copy.deepcopy(YAML_CONFIG) + config["sql"][CONF_DB_URL] = "sqlite+aiosqlite://" + + assert await async_setup_component(hass, DOMAIN, config) await hass.async_block_till_done() state = hass.states.get("sensor.get_value") @@ -372,14 +409,28 @@ async def test_config_from_old_yaml( @pytest.mark.parametrize( - ("url", "expected_patterns", "not_expected_patterns"), + ("patch_create", "url", "expected_patterns", "not_expected_patterns"), [ ( + "homeassistant.components.sql.util.create_async_engine", + "sqlite+aiosqlite://homeassistant:hunter2@homeassistant.local", + ["sqlite+aiosqlite://****:****@homeassistant.local"], + ["sqlite+aiosqlite://homeassistant:hunter2@homeassistant.local"], + ), + ( + "homeassistant.components.sql.util.create_async_engine", + "sqlite+aiosqlite://homeassistant.local", + ["sqlite+aiosqlite://homeassistant.local"], + [], + ), + ( + "homeassistant.components.sql.util.sqlalchemy.create_engine", "sqlite://homeassistant:hunter2@homeassistant.local", ["sqlite://****:****@homeassistant.local"], ["sqlite://homeassistant:hunter2@homeassistant.local"], ), ( + "homeassistant.components.sql.util.sqlalchemy.create_engine", "sqlite://homeassistant.local", ["sqlite://homeassistant.local"], [], @@ -390,6 +441,7 @@ async def test_invalid_url_setup_from_yaml( recorder_mock: Recorder, hass: HomeAssistant, caplog: pytest.LogCaptureFixture, + patch_create: str, url: str, expected_patterns: str, not_expected_patterns: str, @@ -404,11 +456,9 @@ async def test_invalid_url_setup_from_yaml( } } - with patch( - "homeassistant.components.sql.util.sqlalchemy.create_engine", - side_effect=SQLAlchemyError(url), - ): + with patch(patch_create, side_effect=SQLAlchemyError(url)): assert await async_setup_component(hass, DOMAIN, config) + await hass.async_block_till_done() for pattern in not_expected_patterns: @@ -557,6 +607,53 @@ async def test_multiple_sensors_using_same_db( await hass.async_stop() +async def test_multiple_sensors_using_same_external_db( + recorder_mock: Recorder, hass: HomeAssistant, tmp_path: Path +) -> None: + """Test multiple sensors using the same external db.""" + db_path = tmp_path / "test.db" + + # Create and populate the external database + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE users (name TEXT, age INTEGER)") + conn.execute("INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)") + conn.commit() + conn.close() + + config = {CONF_DB_URL: f"sqlite:///{db_path}"} + config2 = {CONF_DB_URL: f"sqlite:///{db_path}"} + options = { + CONF_QUERY: "SELECT name FROM users ORDER BY age LIMIT 1", + CONF_COLUMN_NAME: "name", + } + options2 = { + CONF_QUERY: "SELECT name FROM users ORDER BY age DESC LIMIT 1", + CONF_COLUMN_NAME: "name", + } + + await init_integration( + hass, title="Select name SQL query", config=config, options=options + ) + + assert hass.data["sql"] + assert len(hass.data["sql"].session_makers_by_db_url) == 1 + assert hass.states.get("sensor.select_name_sql_query").state == "Bob" + + await init_integration( + hass, + title="Select name SQL query 2", + config=config2, + options=options2, + entry_id="2", + ) + + assert len(hass.data["sql"].session_makers_by_db_url) == 1 + assert hass.states.get("sensor.select_name_sql_query_2").state == "Alice" + + with patch("sqlalchemy.engine.base.Engine.dispose"): + await hass.async_stop() + + async def test_engine_is_disposed_at_stop( recorder_mock: Recorder, hass: HomeAssistant ) -> None: @@ -628,6 +725,59 @@ async def test_attributes_from_entry_config( assert CONF_STATE_CLASS not in state.attributes +@pytest.mark.parametrize( + ("config", "patch_rollback"), + [ + ( + {CONF_DB_URL: "sqlite+aiosqlite:///"}, + "sqlalchemy.ext.asyncio.session.AsyncSession.rollback", + ), + ( + {}, + "sqlalchemy.orm.session.Session.rollback", + ), + ], +) +async def test_query_rollback_on_error( + recorder_mock: Recorder, + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, + caplog: pytest.LogCaptureFixture, + config: dict[str, Any], + patch_rollback: str, +) -> None: + """Test the SQL sensor.""" + options = { + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIQUE_ID: "very_unique_id", + } + await init_integration( + hass, title="Select value SQL query", config=config, options=options + ) + platforms = async_get_platforms(hass, "sql") + sql_entity = platforms[0].entities["sensor.select_value_sql_query"] + + state = hass.states.get("sensor.select_value_sql_query") + assert state.state == "5" + assert state.attributes["value"] == 5 + + with ( + patch.object( + sql_entity, + "_lambda_stmt", + generate_lambda_stmt("Faulty syntax create operational issue"), + ), + patch(patch_rollback) as mock_rollback, + ): + freezer.tick(timedelta(minutes=1)) + async_fire_time_changed(hass) + await hass.async_block_till_done(wait_background_tasks=True) + assert "sqlite3.OperationalError" in caplog.text + + assert mock_rollback.call_count == 1 + + async def test_query_recover_from_rollback( recorder_mock: Recorder, hass: HomeAssistant, diff --git a/tests/components/sql/test_services.py b/tests/components/sql/test_services.py index ad1fa202153422..39c01259f75bad 100644 --- a/tests/components/sql/test_services.py +++ b/tests/components/sql/test_services.py @@ -7,6 +7,7 @@ from unittest.mock import patch import pytest +from sqlalchemy import text import voluptuous as vol from voluptuous import MultipleInvalid @@ -86,6 +87,64 @@ async def test_query_service_external_db(hass: HomeAssistant, tmp_path: Path) -> } +@pytest.mark.parametrize( + ("async_driver", "patch_rollback"), + [ + ( + True, + "sqlalchemy.ext.asyncio.session.AsyncSession.rollback", + ), + ( + False, + "sqlalchemy.orm.session.Session.rollback", + ), + ], +) +async def test_query_service_rollback_on_error( + hass: HomeAssistant, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, + async_driver: bool, + patch_rollback: str, +) -> None: + """Test the query service.""" + db_path = tmp_path / "test.db" + db_url = f"sqlite{'+aiosqlite' if async_driver else ''}:///{db_path}" + + # Create and populate the external database + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE users (name TEXT, age INTEGER)") + conn.execute("INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)") + conn.commit() + conn.close() + + await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + with ( + patch( + "homeassistant.components.sql.services.generate_lambda_stmt", + return_value=text("Faulty syntax create operational issue"), + ), + pytest.raises( + ServiceValidationError, match="An error occurred when executing the query" + ), + patch(patch_rollback) as mock_session_rollback, + ): + await hass.services.async_call( + DOMAIN, + SERVICE_QUERY, + {"query": "SELECT name, age FROM users ORDER BY age", "db_url": db_url}, + blocking=True, + return_response=True, + ) + + assert "sqlite3.OperationalError" in caplog.text + assert mock_session_rollback.call_count == 1 + + await hass.async_stop() + + async def test_query_service_data_conversion( hass: HomeAssistant, tmp_path: Path ) -> None: @@ -189,7 +248,7 @@ async def test_query_service_invalid_db_url(hass: HomeAssistant) -> None: with ( patch( - "homeassistant.components.sql.util._validate_and_get_session_maker_for_db_url", + "homeassistant.components.sql.util._async_validate_and_get_session_maker_for_db_url", return_value=None, ), pytest.raises( diff --git a/tests/components/sql/test_util.py b/tests/components/sql/test_util.py index 737a5e4a41baac..f63626df6923bf 100644 --- a/tests/components/sql/test_util.py +++ b/tests/components/sql/test_util.py @@ -1,10 +1,17 @@ """Test the sql utils.""" +from datetime import date +from decimal import Decimal + import pytest import voluptuous as vol from homeassistant.components.recorder import Recorder, get_instance -from homeassistant.components.sql.util import resolve_db_url, validate_sql_select +from homeassistant.components.sql.util import ( + ensure_serializable, + resolve_db_url, + validate_sql_select, +) from homeassistant.core import HomeAssistant @@ -64,3 +71,22 @@ async def test_invalid_sql_queries( """Test that various invalid or disallowed SQL queries raise the correct exception.""" with pytest.raises(vol.Invalid, match=expected_error_message): validate_sql_select(sql_query) + + +@pytest.mark.parametrize( + ("input", "expected_output"), + [ + (Decimal("199.99"), 199.99), + (date(2023, 1, 15), "2023-01-15"), + (b"\xde\xad\xbe\xef", "0xdeadbeef"), + ("deadbeef", "deadbeef"), + (199.99, 199.99), + (69, 69), + ], +) +async def test_data_conversion( + input: Decimal | date | bytes | str | float, + expected_output: str | float, +) -> None: + """Test data conversion to serializable type.""" + assert ensure_serializable(input) == expected_output