Skip to content

Commit 144fc2a

Browse files
authored
Refactor SQL's data conversion (#155598)
1 parent c67e005 commit 144fc2a

File tree

4 files changed

+48
-22
lines changed

4 files changed

+48
-22
lines changed

homeassistant/components/sql/sensor.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from __future__ import annotations
44

5-
from datetime import date
6-
import decimal
75
import logging
86
from typing import Any
97

@@ -43,6 +41,7 @@
4341
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
4442
from .util import (
4543
async_create_sessionmaker,
44+
convert_value,
4645
generate_lambda_stmt,
4746
redact_credentials,
4847
resolve_db_url,
@@ -253,7 +252,6 @@ async def async_update(self) -> None:
253252
def _update(self) -> None:
254253
"""Retrieve sensor data from the query."""
255254
data = None
256-
extra_state_attributes = {}
257255
self._attr_extra_state_attributes = {}
258256
sess: scoped_session = self.sessionmaker()
259257
try:
@@ -272,14 +270,7 @@ def _update(self) -> None:
272270
_LOGGER.debug("Query %s result in %s", self._query, res.items())
273271
data = res[self._column_name]
274272
for key, value in res.items():
275-
if isinstance(value, decimal.Decimal):
276-
value = float(value)
277-
elif isinstance(value, date):
278-
value = value.isoformat()
279-
elif isinstance(value, (bytes, bytearray)):
280-
value = f"0x{value.hex()}"
281-
extra_state_attributes[key] = value
282-
self._attr_extra_state_attributes[key] = value
273+
self._attr_extra_state_attributes[key] = convert_value(value)
283274

284275
if data is not None and isinstance(data, (bytes, bytearray)):
285276
data = f"0x{data.hex()}"

homeassistant/components/sql/services.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from __future__ import annotations
44

5-
import datetime
6-
import decimal
75
import logging
86

97
from sqlalchemy.engine import Result
@@ -26,6 +24,7 @@
2624
from .const import CONF_QUERY, DOMAIN
2725
from .util import (
2826
async_create_sessionmaker,
27+
convert_value,
2928
generate_lambda_stmt,
3029
redact_credentials,
3130
resolve_db_url,
@@ -88,14 +87,7 @@ def _execute_and_convert_query() -> list[JsonValueType]:
8887
for row in result.mappings():
8988
processed_row: dict[str, JsonValueType] = {}
9089
for key, value in row.items():
91-
if isinstance(value, decimal.Decimal):
92-
processed_row[key] = float(value)
93-
elif isinstance(value, datetime.date):
94-
processed_row[key] = value.isoformat()
95-
elif isinstance(value, (bytes, bytearray)):
96-
processed_row[key] = f"0x{value.hex()}"
97-
else:
98-
processed_row[key] = value
90+
processed_row[key] = convert_value(value)
9991
rows.append(processed_row)
10092
return rows
10193
finally:

homeassistant/components/sql/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
from __future__ import annotations
44

5+
from datetime import date
6+
from decimal import Decimal
57
import logging
8+
from typing import Any
69

710
import sqlalchemy
811
from sqlalchemy import lambda_stmt
@@ -223,3 +226,16 @@ def generate_lambda_stmt(query: str) -> StatementLambdaElement:
223226
"""Generate the lambda statement."""
224227
text = sqlalchemy.text(query)
225228
return lambda_stmt(lambda: text, lambda_cache=_SQL_LAMBDA_CACHE)
229+
230+
231+
def convert_value(value: Any) -> Any:
232+
"""Convert value."""
233+
match value:
234+
case Decimal():
235+
return float(value)
236+
case date():
237+
return value.isoformat()
238+
case bytes() | bytearray():
239+
return f"0x{value.hex()}"
240+
case _:
241+
return value

tests/components/sql/test_util.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
"""Test the sql utils."""
22

3+
from datetime import UTC, date, datetime
4+
from decimal import Decimal
5+
36
import pytest
47
import voluptuous as vol
58

69
from homeassistant.components.recorder import Recorder, get_instance
7-
from homeassistant.components.sql.util import resolve_db_url, validate_sql_select
10+
from homeassistant.components.sql.util import (
11+
convert_value,
12+
resolve_db_url,
13+
validate_sql_select,
14+
)
815
from homeassistant.core import HomeAssistant
916

1017

@@ -64,3 +71,23 @@ async def test_invalid_sql_queries(
6471
"""Test that various invalid or disallowed SQL queries raise the correct exception."""
6572
with pytest.raises(vol.Invalid, match=expected_error_message):
6673
validate_sql_select(sql_query)
74+
75+
76+
@pytest.mark.parametrize(
77+
("input", "expected_output"),
78+
[
79+
(Decimal("199.99"), 199.99),
80+
(date(2023, 1, 15), "2023-01-15"),
81+
(datetime(2023, 1, 15, 12, 30, 45, tzinfo=UTC), "2023-01-15T12:30:45+00:00"),
82+
(b"\xde\xad\xbe\xef", "0xdeadbeef"),
83+
("deadbeef", "deadbeef"),
84+
(199.99, 199.99),
85+
(69, 69),
86+
],
87+
)
88+
async def test_value_conversion(
89+
input: Decimal | date | datetime | bytes | str | float,
90+
expected_output: str | float,
91+
) -> None:
92+
"""Test value conversion."""
93+
assert convert_value(input) == expected_output

0 commit comments

Comments
 (0)