Skip to content

Commit 4c04dc0

Browse files
authored
Respect callback decorator in store helper async_delay_save (home-assistant#157158)
1 parent 0c36650 commit 4c04dc0

File tree

5 files changed

+136
-28
lines changed

5 files changed

+136
-28
lines changed

homeassistant/helpers/json.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,20 @@ def _orjson_bytes_default_encoder(data: Any) -> bytes:
160160
)
161161

162162

163-
def save_json(
164-
filename: str,
163+
def prepare_save_json(
165164
data: list | dict,
166-
private: bool = False,
167165
*,
168166
encoder: type[json.JSONEncoder] | None = None,
169-
atomic_writes: bool = False,
170-
) -> None:
171-
"""Save JSON data to a file."""
167+
) -> tuple[str, str | bytes]:
168+
"""Prepare JSON data for saving to a file.
169+
170+
Returns a tuple of (mode, json_data) where mode is either 'w' or 'wb'
171+
and json_data is either a str or bytes depending on the mode.
172+
173+
Args:
174+
data: Data to serialize.
175+
encoder: Optional custom JSON encoder.
176+
"""
172177
dump: Callable[[Any], Any]
173178
try:
174179
# For backwards compatibility, if they pass in the
@@ -188,10 +193,24 @@ def save_json(
188193
formatted_data = format_unserializable_data(
189194
find_paths_unserializable_data(data, dump=dump)
190195
)
191-
msg = f"Failed to serialize to JSON: {filename}. Bad data at {formatted_data}"
192-
_LOGGER.error(msg)
193-
raise SerializationError(msg) from error
196+
raise SerializationError(f"Bad data at {formatted_data}") from error
197+
return (mode, json_data)
198+
194199

200+
def save_json(
201+
filename: str,
202+
data: list | dict,
203+
private: bool = False,
204+
*,
205+
encoder: type[json.JSONEncoder] | None = None,
206+
atomic_writes: bool = False,
207+
) -> None:
208+
"""Save JSON data to a file."""
209+
try:
210+
mode, json_data = prepare_save_json(data, encoder=encoder)
211+
except SerializationError as err:
212+
_LOGGER.error("Failed to serialize to JSON: %s. %s", filename, err)
213+
raise
195214
method = write_utf8_file_atomic if atomic_writes else write_utf8_file
196215
method(filename, json_data, private, mode=mode)
197216

homeassistant/helpers/storage.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@
2727
Event,
2828
HomeAssistant,
2929
callback,
30+
is_callback,
3031
)
3132
from homeassistant.exceptions import HomeAssistantError
3233
from homeassistant.loader import bind_hass
3334
from homeassistant.util import dt as dt_util, json as json_util
34-
from homeassistant.util.file import WriteError
35+
from homeassistant.util.file import WriteError, write_utf8_file, write_utf8_file_atomic
3536
from homeassistant.util.hass_dict import HassKey
3637

3738
from . import json as json_helper
@@ -441,7 +442,12 @@ def async_delay_save(
441442
data_func: Callable[[], _T],
442443
delay: float = 0,
443444
) -> None:
444-
"""Save data with an optional delay."""
445+
"""Save data with an optional delay.
446+
447+
data_func: A function that returns the data to save. If the function
448+
is decorated with @callback, it will be called in the event loop. If
449+
it is a regular function, it will be called from an executor.
450+
"""
445451
self._data = {
446452
"version": self.version,
447453
"minor_version": self.minor_version,
@@ -537,28 +543,37 @@ async def _async_handle_write_data(self, *_args):
537543
return
538544

539545
try:
540-
await self._async_write_data(self.path, data)
546+
await self._async_write_data(data)
541547
except (json_util.SerializationError, WriteError) as err:
542548
_LOGGER.error("Error writing config for %s: %s", self.key, err)
543549

544-
async def _async_write_data(self, path: str, data: dict) -> None:
545-
await self.hass.async_add_executor_job(self._write_data, self.path, data)
550+
async def _async_write_data(self, data: dict) -> None:
551+
if "data_func" in data and is_callback(data["data_func"]):
552+
data["data"] = data.pop("data_func")()
553+
mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder)
554+
await self.hass.async_add_executor_job(
555+
self._write_prepared_data, mode, json_data
556+
)
557+
return
558+
await self.hass.async_add_executor_job(self._write_data, data)
546559

547-
def _write_data(self, path: str, data: dict) -> None:
560+
def _write_data(self, data: dict) -> None:
548561
"""Write the data."""
549-
os.makedirs(os.path.dirname(path), exist_ok=True)
550-
551562
if "data_func" in data:
552563
data["data"] = data.pop("data_func")()
564+
mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder)
565+
self._write_prepared_data(mode, json_data)
566+
567+
def _write_prepared_data(self, mode: str, json_data: str | bytes) -> None:
568+
"""Write the data."""
569+
path = self.path
570+
os.makedirs(os.path.dirname(path), exist_ok=True)
553571

554572
_LOGGER.debug("Writing data for %s to %s", self.key, path)
555-
json_helper.save_json(
556-
path,
557-
data,
558-
self._private,
559-
encoder=self._encoder,
560-
atomic_writes=self._atomic_writes,
573+
write_method = (
574+
write_utf8_file_atomic if self._atomic_writes else write_utf8_file
561575
)
576+
write_method(path, json_data, self._private, mode=mode)
562577

563578
async def _async_migrate_func(self, old_major_version, old_minor_version, old_data):
564579
"""Migrate to the new version."""

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1531,7 +1531,7 @@ async def mock_async_load(
15311531
return loaded
15321532

15331533
async def mock_write_data(
1534-
store: storage.Store, path: str, data_to_write: dict[str, Any]
1534+
store: storage.Store, data_to_write: dict[str, Any]
15351535
) -> None:
15361536
"""Mock version of write data."""
15371537
# To ensure that the data can be serialized

tests/helpers/test_json.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,7 @@ class CannotSerializeMe:
237237
with pytest.raises(SerializationError) as excinfo:
238238
save_json("test4", {"hello": CannotSerializeMe()})
239239

240-
assert "Failed to serialize to JSON: test4. Bad data at $.hello=" in str(
241-
excinfo.value
242-
)
240+
assert "Bad data at $.hello=" in str(excinfo.value)
243241

244242

245243
def test_custom_encoder(tmp_path: Path) -> None:

tests/helpers/test_storage.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from datetime import timedelta
55
import json
66
import os
7+
from pathlib import Path
8+
import threading
79
from typing import Any, NamedTuple
810
from unittest.mock import Mock, patch
911

@@ -17,7 +19,12 @@
1719
EVENT_HOMEASSISTANT_STARTED,
1820
EVENT_HOMEASSISTANT_STOP,
1921
)
20-
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, CoreState, HomeAssistant
22+
from homeassistant.core import (
23+
DOMAIN as HOMEASSISTANT_DOMAIN,
24+
CoreState,
25+
HomeAssistant,
26+
callback,
27+
)
2128
from homeassistant.exceptions import HomeAssistantError
2229
from homeassistant.helpers import issue_registry as ir, storage
2330
from homeassistant.helpers.json import json_bytes
@@ -35,6 +42,7 @@
3542
MOCK_MINOR_VERSION_1 = 1
3643
MOCK_MINOR_VERSION_2 = 2
3744
MOCK_KEY = "storage-test"
45+
MOCK_KEY2 = "storage-test-2"
3846
MOCK_DATA = {"hello": "world"}
3947
MOCK_DATA2 = {"goodbye": "cruel world"}
4048

@@ -140,6 +148,74 @@ async def test_saving_with_delay(
140148
}
141149

142150

151+
async def test_saving_with_delay_threading(tmp_path: Path) -> None:
152+
"""Test thread handling when saving with a delay."""
153+
calls = []
154+
155+
async def assert_storage_data(store_key: str, expected_data: str) -> None:
156+
"""Assert storage data."""
157+
158+
def read_storage_data(store_key: str) -> str:
159+
"""Read storage data."""
160+
return Path(tmp_path / f".storage/{store_key}").read_text(encoding="utf-8")
161+
162+
store_data = await asyncio.to_thread(read_storage_data, store_key)
163+
assert store_data == expected_data
164+
165+
async with async_test_home_assistant(config_dir=tmp_path) as hass:
166+
167+
def data_producer_thread_safe() -> Any:
168+
"""Produce data to store."""
169+
assert threading.get_ident() != hass.loop_thread_id
170+
calls.append("thread_safe")
171+
return MOCK_DATA
172+
173+
@callback
174+
def data_producer_callback() -> Any:
175+
"""Produce data to store."""
176+
assert threading.get_ident() == hass.loop_thread_id
177+
calls.append("callback")
178+
return MOCK_DATA2
179+
180+
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
181+
store.async_delay_save(data_producer_thread_safe, 1)
182+
183+
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=1))
184+
await hass.async_block_till_done()
185+
186+
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY2)
187+
store.async_delay_save(data_producer_callback, 1)
188+
189+
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=1))
190+
await hass.async_block_till_done()
191+
192+
assert calls == ["thread_safe", "callback"]
193+
expected_data = (
194+
"{\n"
195+
' "version": 1,\n'
196+
' "minor_version": 1,\n'
197+
' "key": "storage-test",\n'
198+
' "data": {\n'
199+
' "hello": "world"\n'
200+
" }\n"
201+
"}"
202+
)
203+
await assert_storage_data(MOCK_KEY, expected_data)
204+
expected_data = (
205+
"{\n"
206+
' "version": 1,\n'
207+
' "minor_version": 1,\n'
208+
' "key": "storage-test-2",\n'
209+
' "data": {\n'
210+
' "goodbye": "cruel world"\n'
211+
" }\n"
212+
"}"
213+
)
214+
await assert_storage_data(MOCK_KEY2, expected_data)
215+
216+
await hass.async_stop(force=True)
217+
218+
143219
async def test_saving_with_delay_churn_reduction(
144220
hass: HomeAssistant,
145221
store: storage.Store,

0 commit comments

Comments
 (0)