Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions homeassistant/components/sql/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from datetime import date
import decimal
import logging
from typing import Any

Expand Down Expand Up @@ -43,6 +41,7 @@
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
convert_value,
generate_lambda_stmt,
redact_credentials,
resolve_db_url,
Expand Down Expand Up @@ -253,7 +252,6 @@ async def async_update(self) -> None:
def _update(self) -> None:
"""Retrieve sensor data from the query."""
data = None
extra_state_attributes = {}
self._attr_extra_state_attributes = {}
sess: scoped_session = self.sessionmaker()
try:
Expand All @@ -272,14 +270,7 @@ def _update(self) -> None:
_LOGGER.debug("Query %s result in %s", self._query, res.items())
data = res[self._column_name]
for key, value in res.items():
if isinstance(value, decimal.Decimal):
value = float(value)
elif isinstance(value, date):
value = value.isoformat()
elif isinstance(value, (bytes, bytearray)):
value = f"0x{value.hex()}"
extra_state_attributes[key] = value
self._attr_extra_state_attributes[key] = value
self._attr_extra_state_attributes[key] = convert_value(value)

if data is not None and isinstance(data, (bytes, bytearray)):
data = f"0x{data.hex()}"
Expand Down
12 changes: 2 additions & 10 deletions homeassistant/components/sql/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

import datetime
import decimal
import logging

from sqlalchemy.engine import Result
Expand All @@ -26,6 +24,7 @@
from .const import CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
convert_value,
generate_lambda_stmt,
redact_credentials,
resolve_db_url,
Expand Down Expand Up @@ -88,14 +87,7 @@ def _execute_and_convert_query() -> list[JsonValueType]:
for row in result.mappings():
processed_row: dict[str, JsonValueType] = {}
for key, value in row.items():
if isinstance(value, decimal.Decimal):
processed_row[key] = float(value)
elif isinstance(value, datetime.date):
processed_row[key] = value.isoformat()
elif isinstance(value, (bytes, bytearray)):
processed_row[key] = f"0x{value.hex()}"
else:
processed_row[key] = value
processed_row[key] = convert_value(value)
rows.append(processed_row)
return rows
finally:
Expand Down
16 changes: 16 additions & 0 deletions homeassistant/components/sql/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from __future__ import annotations

from datetime import date
from decimal import Decimal
import logging
from typing import Any

import sqlalchemy
from sqlalchemy import lambda_stmt
Expand Down Expand Up @@ -223,3 +226,16 @@ def generate_lambda_stmt(query: str) -> StatementLambdaElement:
"""Generate the lambda statement."""
text = sqlalchemy.text(query)
return lambda_stmt(lambda: text, lambda_cache=_SQL_LAMBDA_CACHE)


def convert_value(value: Any) -> Any:
"""Convert value."""
match value:
case Decimal():
return float(value)
case date():
return value.isoformat()
case bytes() | bytearray():
return f"0x{value.hex()}"
case _:
return value
29 changes: 28 additions & 1 deletion tests/components/sql/test_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
"""Test the sql utils."""

from datetime import UTC, date, datetime
from decimal import Decimal

import pytest
import voluptuous as vol

from homeassistant.components.recorder import Recorder, get_instance
from homeassistant.components.sql.util import resolve_db_url, validate_sql_select
from homeassistant.components.sql.util import (
convert_value,
resolve_db_url,
validate_sql_select,
)
from homeassistant.core import HomeAssistant


Expand Down Expand Up @@ -64,3 +71,23 @@ async def test_invalid_sql_queries(
"""Test that various invalid or disallowed SQL queries raise the correct exception."""
with pytest.raises(vol.Invalid, match=expected_error_message):
validate_sql_select(sql_query)


@pytest.mark.parametrize(
("input", "expected_output"),
[
(Decimal("199.99"), 199.99),
(date(2023, 1, 15), "2023-01-15"),
(datetime(2023, 1, 15, 12, 30, 45, tzinfo=UTC), "2023-01-15T12:30:45+00:00"),
(b"\xde\xad\xbe\xef", "0xdeadbeef"),
("deadbeef", "deadbeef"),
(199.99, 199.99),
(69, 69),
],
)
async def test_value_conversion(
input: Decimal | date | datetime | bytes | str | float,
expected_output: str | float,
) -> None:
"""Test value conversion."""
assert convert_value(input) == expected_output
Loading