Skip to content
Open

PR #6

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion homeassistant/components/sql/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,11 @@
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/sql",
"iot_class": "local_polling",
"requirements": ["SQLAlchemy==2.0.41", "sqlparse==0.5.0"]
"requirements": [
"aiomysql==0.3.2",
"aiosqlite==0.21.0",
"asyncpg==0.30.0",
"SQLAlchemy[asyncio]==2.0.41",
"sqlparse==0.5.0"
]
}
7 changes: 5 additions & 2 deletions homeassistant/components/sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from dataclasses import dataclass

from sqlalchemy.orm import scoped_session
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, scoped_session

from homeassistant.core import CALLBACK_TYPE

Expand All @@ -14,4 +15,6 @@ class SQLData:
"""Data for the sql integration."""

shutdown_event_cancel: CALLBACK_TYPE
session_makers_by_db_url: dict[str, scoped_session]
session_makers_by_db_url: dict[
str, async_scoped_session[AsyncSession] | scoped_session[Session]
]
92 changes: 50 additions & 42 deletions homeassistant/components/sql/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

from __future__ import annotations

from datetime import date
import decimal
import logging
from typing import Any
from typing import TYPE_CHECKING, Any

from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, scoped_session

from homeassistant.components.recorder import CONF_DB_URL, get_instance
from homeassistant.components.sensor import CONF_STATE_CLASS
Expand Down Expand Up @@ -43,6 +42,7 @@
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
ensure_serializable,
generate_lambda_stmt,
redact_credentials,
resolve_db_url,
Expand Down Expand Up @@ -200,7 +200,7 @@ class SQLSensor(ManualTriggerSensorEntity):
def __init__(
self,
trigger_entity_config: ConfigType,
sessmaker: scoped_session,
sessmaker: async_scoped_session[AsyncSession] | scoped_session[Session],
query: str,
column: str,
value_template: ValueTemplate | None,
Expand Down Expand Up @@ -243,43 +243,16 @@ def extra_state_attributes(self) -> dict[str, Any] | None:
"""Return extra attributes."""
return dict(self._attr_extra_state_attributes)

async def async_update(self) -> None:
"""Retrieve sensor data from the query using the right executor."""
if self._use_database_executor:
await get_instance(self.hass).async_add_executor_job(self._update)
else:
await self.hass.async_add_executor_job(self._update)

def _update(self) -> None:
"""Retrieve sensor data from the query."""
def _process(self, result: Result) -> None:
"""Process the SQL result."""
data = None
extra_state_attributes = {}
self._attr_extra_state_attributes = {}
sess: scoped_session = self.sessionmaker()
try:
result: Result = sess.execute(self._lambda_stmt)
except SQLAlchemyError as err:
_LOGGER.error(
"Error executing query %s: %s",
self._query,
redact_credentials(str(err)),
)
sess.rollback()
sess.close()
return

for res in result.mappings():
_LOGGER.debug("Query %s result in %s", self._query, res.items())
data = res[self._column_name]
for key, value in res.items():
if isinstance(value, decimal.Decimal):
value = float(value)
elif isinstance(value, date):
value = value.isoformat()
elif isinstance(value, (bytes, bytearray)):
value = f"0x{value.hex()}"
extra_state_attributes[key] = value
self._attr_extra_state_attributes[key] = value

for row in result.mappings():
row_items = row.items()
_LOGGER.debug("Query %s result in %s", self._query, row_items)
data = row[self._column_name]
for key, value in row_items:
self._attr_extra_state_attributes[key] = ensure_serializable(value)

if data is not None and isinstance(data, (bytes, bytearray)):
data = f"0x{data.hex()}"
Expand All @@ -298,4 +271,39 @@ def _update(self) -> None:
if data is None:
_LOGGER.warning("%s returned no results", self._query)

sess.close()
def _update(self) -> None:
"""Retrieve sensor data from the query.

This does I/O and should be run in the executor.
"""
if TYPE_CHECKING:
assert isinstance(self.sessionmaker, scoped_session)
with self.sessionmaker() as session:
try:
self._process(session.execute(self._lambda_stmt))
except SQLAlchemyError as err:
_LOGGER.error(
"Error executing query %s: %s",
self._query,
redact_credentials(str(err)),
)
session.rollback()

async def async_update(self) -> None:
"""Retrieve sensor data from the query using the right executor."""
self._attr_extra_state_attributes = {}
if isinstance(self.sessionmaker, async_scoped_session):
async with self.sessionmaker() as session:
try:
self._process(await session.execute(self._lambda_stmt))
except SQLAlchemyError as err:
_LOGGER.error(
"Error executing query %s: %s",
self._query,
redact_credentials(str(err)),
)
await session.rollback()
elif self._use_database_executor:
await get_instance(self.hass).async_add_executor_job(self._update)
else:
await self.hass.async_add_executor_job(self._update)
73 changes: 41 additions & 32 deletions homeassistant/components/sql/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from __future__ import annotations

import datetime
import decimal
import logging
from typing import TYPE_CHECKING

from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import scoped_session
import voluptuous as vol

from homeassistant.components.recorder import CONF_DB_URL, get_instance
Expand All @@ -26,6 +26,7 @@
from .const import CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
ensure_serializable,
generate_lambda_stmt,
redact_credentials,
resolve_db_url,
Expand Down Expand Up @@ -70,39 +71,47 @@ async def _async_query_service(
translation_placeholders={"db_url": redact_credentials(db_url)},
)

def _process(result: Result) -> list[JsonValueType]:
rows: list[JsonValueType] = []
for row in result.mappings():
processed_row: dict[str, JsonValueType] = {}
for key, value in row.items():
processed_row[key] = ensure_serializable(value)
rows.append(processed_row)
return rows

def _execute_and_convert_query() -> list[JsonValueType]:
"""Execute the query and return the results with converted types."""
sess: Session = sessmaker()
try:
result: Result = sess.execute(generate_lambda_stmt(query_str))
except SQLAlchemyError as err:
_LOGGER.debug(
"Error executing query %s: %s",
query_str,
redact_credentials(str(err)),
)
sess.rollback()
raise
else:
rows: list[JsonValueType] = []
for row in result.mappings():
processed_row: dict[str, JsonValueType] = {}
for key, value in row.items():
if isinstance(value, decimal.Decimal):
processed_row[key] = float(value)
elif isinstance(value, datetime.date):
processed_row[key] = value.isoformat()
elif isinstance(value, (bytes, bytearray)):
processed_row[key] = f"0x{value.hex()}"
else:
processed_row[key] = value
rows.append(processed_row)
return rows
finally:
sess.close()
if TYPE_CHECKING:
assert isinstance(sessmaker, scoped_session)
with sessmaker() as session:
try:
return _process(session.execute(generate_lambda_stmt(query_str)))
except SQLAlchemyError as err:
_LOGGER.debug(
"Error executing query %s: %s",
query_str,
redact_credentials(str(err)),
)
session.rollback()
raise

try:
if use_database_executor:
if isinstance(sessmaker, async_scoped_session):
async with sessmaker() as session:
try:
result = _process(
await session.execute(generate_lambda_stmt(query_str))
)
except SQLAlchemyError as err:
_LOGGER.debug(
"Error executing query %s: %s",
query_str,
redact_credentials(str(err)),
)
await session.rollback()
raise
elif use_database_executor:
result = await get_instance(call.hass).async_add_executor_job(
_execute_and_convert_query
)
Expand Down
Loading
Loading