Skip to content

Commit dfbaf66

Browse files
MindFreezecdce8pMartinHjelmare
authored
Add progress step decorator for easier config flows (home-assistant#152739)
Co-authored-by: Marc Mueller <[email protected]> Co-authored-by: Martin Hjelmare <[email protected]>
1 parent 62cea48 commit dfbaf66

File tree

3 files changed

+167
-70
lines changed

3 files changed

+167
-70
lines changed

homeassistant/components/homeassistant_hardware/firmware_config_flow.py

Lines changed: 30 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
OptionsFlow,
2929
)
3030
from homeassistant.core import callback
31-
from homeassistant.data_entry_flow import AbortFlow
31+
from homeassistant.data_entry_flow import AbortFlow, progress_step
3232
from homeassistant.exceptions import HomeAssistantError
3333
from homeassistant.helpers.aiohttp_client import async_get_clientsession
3434
from homeassistant.helpers.hassio import is_hassio
@@ -72,8 +72,6 @@ class BaseFirmwareInstallFlow(ConfigEntryBaseFlow, ABC):
7272
"""Base flow to install firmware."""
7373

7474
ZIGBEE_BAUDRATE = 115200 # Default, subclasses may override
75-
_failed_addon_name: str
76-
_failed_addon_reason: str
7775
_picked_firmware_type: PickedFirmwareType
7876

7977
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -85,8 +83,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
8583
self._hardware_name: str = "unknown" # To be set in a subclass
8684
self._zigbee_integration = ZigbeeIntegration.ZHA
8785

88-
self.addon_install_task: asyncio.Task | None = None
89-
self.addon_start_task: asyncio.Task | None = None
9086
self.addon_uninstall_task: asyncio.Task | None = None
9187
self.firmware_install_task: asyncio.Task[None] | None = None
9288
self.installing_firmware_name: str | None = None
@@ -486,18 +482,6 @@ async def async_step_install_zigbee_firmware(
486482
"""Install Zigbee firmware."""
487483
raise NotImplementedError
488484

489-
async def async_step_addon_operation_failed(
490-
self, user_input: dict[str, Any] | None = None
491-
) -> ConfigFlowResult:
492-
"""Abort when add-on installation or start failed."""
493-
return self.async_abort(
494-
reason=self._failed_addon_reason,
495-
description_placeholders={
496-
**self._get_translation_placeholders(),
497-
"addon_name": self._failed_addon_name,
498-
},
499-
)
500-
501485
async def async_step_pre_confirm_zigbee(
502486
self, user_input: dict[str, Any] | None = None
503487
) -> ConfigFlowResult:
@@ -561,6 +545,12 @@ async def async_step_install_thread_firmware(
561545
"""Install Thread firmware."""
562546
raise NotImplementedError
563547

548+
@progress_step(
549+
description_placeholders=lambda self: {
550+
**self._get_translation_placeholders(),
551+
"addon_name": get_otbr_addon_manager(self.hass).addon_name,
552+
}
553+
)
564554
async def async_step_install_otbr_addon(
565555
self, user_input: dict[str, Any] | None = None
566556
) -> ConfigFlowResult:
@@ -570,70 +560,43 @@ async def async_step_install_otbr_addon(
570560

571561
_LOGGER.debug("OTBR addon info: %s", addon_info)
572562

573-
if not self.addon_install_task:
574-
self.addon_install_task = self.hass.async_create_task(
575-
addon_manager.async_install_addon_waiting(),
576-
"OTBR addon install",
577-
)
578-
579-
if not self.addon_install_task.done():
580-
return self.async_show_progress(
581-
step_id="install_otbr_addon",
582-
progress_action="install_addon",
563+
try:
564+
await addon_manager.async_install_addon_waiting()
565+
except AddonError as err:
566+
_LOGGER.error(err)
567+
raise AbortFlow(
568+
"addon_install_failed",
583569
description_placeholders={
584570
**self._get_translation_placeholders(),
585571
"addon_name": addon_manager.addon_name,
586572
},
587-
progress_task=self.addon_install_task,
588-
)
589-
590-
try:
591-
await self.addon_install_task
592-
except AddonError as err:
593-
_LOGGER.error(err)
594-
self._failed_addon_name = addon_manager.addon_name
595-
self._failed_addon_reason = "addon_install_failed"
596-
return self.async_show_progress_done(next_step_id="addon_operation_failed")
597-
finally:
598-
self.addon_install_task = None
573+
) from err
599574

600-
return self.async_show_progress_done(next_step_id="finish_thread_installation")
575+
return await self.async_step_finish_thread_installation()
601576

577+
@progress_step(
578+
description_placeholders=lambda self: {
579+
**self._get_translation_placeholders(),
580+
"addon_name": get_otbr_addon_manager(self.hass).addon_name,
581+
}
582+
)
602583
async def async_step_start_otbr_addon(
603584
self, user_input: dict[str, Any] | None = None
604585
) -> ConfigFlowResult:
605586
"""Configure OTBR to point to the SkyConnect and run the addon."""
606-
otbr_manager = get_otbr_addon_manager(self.hass)
607-
608-
if not self.addon_start_task:
609-
self.addon_start_task = self.hass.async_create_task(
610-
self._configure_and_start_otbr_addon()
611-
)
612-
613-
if not self.addon_start_task.done():
614-
return self.async_show_progress(
615-
step_id="start_otbr_addon",
616-
progress_action="start_otbr_addon",
587+
try:
588+
await self._configure_and_start_otbr_addon()
589+
except AddonError as err:
590+
_LOGGER.error(err)
591+
raise AbortFlow(
592+
"addon_start_failed",
617593
description_placeholders={
618594
**self._get_translation_placeholders(),
619-
"addon_name": otbr_manager.addon_name,
595+
"addon_name": get_otbr_addon_manager(self.hass).addon_name,
620596
},
621-
progress_task=self.addon_start_task,
622-
)
623-
624-
try:
625-
await self.addon_start_task
626-
except (AddonError, AbortFlow) as err:
627-
_LOGGER.error(err)
628-
self._failed_addon_name = otbr_manager.addon_name
629-
self._failed_addon_reason = (
630-
err.reason if isinstance(err, AbortFlow) else "addon_start_failed"
631-
)
632-
return self.async_show_progress_done(next_step_id="addon_operation_failed")
633-
finally:
634-
self.addon_start_task = None
597+
) from err
635598

636-
return self.async_show_progress_done(next_step_id="pre_confirm_otbr")
599+
return await self.async_step_pre_confirm_otbr()
637600

638601
async def async_step_pre_confirm_otbr(
639602
self, user_input: dict[str, Any] | None = None

homeassistant/data_entry_flow.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import abc
66
import asyncio
77
from collections import defaultdict
8-
from collections.abc import Callable, Container, Hashable, Iterable, Mapping
8+
from collections.abc import Callable, Container, Coroutine, Hashable, Iterable, Mapping
99
from contextlib import suppress
1010
import copy
1111
from dataclasses import dataclass
1212
from enum import StrEnum
13+
import functools
1314
import logging
1415
from types import MappingProxyType
15-
from typing import Any, Generic, Required, TypedDict, TypeVar, cast
16+
from typing import Any, Concatenate, Generic, Required, TypedDict, TypeVar, cast
1617

1718
import voluptuous as vol
1819

@@ -150,6 +151,15 @@ class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False):
150151
url: str
151152

152153

154+
class ProgressStepData[_FlowResultT](TypedDict):
155+
"""Typed data for progress step tracking."""
156+
157+
tasks: dict[str, asyncio.Task[Any]]
158+
abort_reason: str
159+
abort_description_placeholders: Mapping[str, str]
160+
next_step_result: _FlowResultT | None
161+
162+
153163
def _map_error_to_schema_errors(
154164
schema_errors: dict[str, Any],
155165
error: vol.Invalid,
@@ -639,6 +649,12 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
639649
__progress_task: asyncio.Task[Any] | None = None
640650
__no_progress_task_reported = False
641651
deprecated_show_progress = False
652+
_progress_step_data: ProgressStepData[_FlowResultT] = {
653+
"tasks": {},
654+
"abort_reason": "",
655+
"abort_description_placeholders": MappingProxyType({}),
656+
"next_step_result": None,
657+
}
642658

643659
@property
644660
def source(self) -> str | None:
@@ -761,6 +777,37 @@ def async_abort(
761777
description_placeholders=description_placeholders,
762778
)
763779

780+
async def async_step__progress_step_abort(
781+
self, user_input: dict[str, Any] | None = None
782+
) -> _FlowResultT:
783+
"""Abort the flow."""
784+
return self.async_abort(
785+
reason=self._progress_step_data["abort_reason"],
786+
description_placeholders=self._progress_step_data[
787+
"abort_description_placeholders"
788+
],
789+
)
790+
791+
async def async_step__progress_step_progress_done(
792+
self, user_input: dict[str, Any] | None = None
793+
) -> _FlowResultT:
794+
"""Progress done. Return the next step.
795+
796+
Used by the progress_step decorator
797+
to allow decorated step methods
798+
to call the next step method, to change step,
799+
without using async_show_progress_done.
800+
If no next step is set, abort the flow.
801+
"""
802+
if self._progress_step_data["next_step_result"] is None:
803+
return self.async_abort(
804+
reason=self._progress_step_data["abort_reason"],
805+
description_placeholders=self._progress_step_data[
806+
"abort_description_placeholders"
807+
],
808+
)
809+
return self._progress_step_data["next_step_result"]
810+
764811
@callback
765812
def async_external_step(
766813
self,
@@ -930,3 +977,90 @@ def __init__(
930977
def __call__(self, value: Any) -> Any:
931978
"""Validate input."""
932979
return self.schema(value)
980+
981+
982+
type _FuncType[_T: FlowHandler[Any, Any, Any], _R: FlowResult[Any, Any], **_P] = (
983+
Callable[Concatenate[_T, _P], Coroutine[Any, Any, _R]]
984+
)
985+
986+
987+
def progress_step[
988+
HandlerT: FlowHandler[Any, Any, Any],
989+
ResultT: FlowResult[Any, Any],
990+
**P,
991+
](
992+
description_placeholders: (
993+
dict[str, str] | Callable[[Any], dict[str, str]] | None
994+
) = None,
995+
) -> Callable[[_FuncType[HandlerT, ResultT, P]], _FuncType[HandlerT, ResultT, P]]:
996+
"""Decorator to create a progress step from an async function.
997+
998+
The decorated method should be a step method
999+
which needs to show progress.
1000+
The method should accept dict[str, Any] as user_input
1001+
and should return a FlowResult or raise AbortFlow.
1002+
The method can call self.async_update_progress(progress)
1003+
to update progress.
1004+
1005+
Args:
1006+
description_placeholders: Static dict or callable that returns dict for progress UI placeholders.
1007+
"""
1008+
1009+
def decorator(
1010+
func: _FuncType[HandlerT, ResultT, P],
1011+
) -> _FuncType[HandlerT, ResultT, P]:
1012+
@functools.wraps(func)
1013+
async def wrapper(
1014+
self: FlowHandler[Any, ResultT], *args: P.args, **kwargs: P.kwargs
1015+
) -> ResultT:
1016+
step_id = func.__name__.replace("async_step_", "")
1017+
1018+
# Check if we have a progress task running
1019+
progress_task = self._progress_step_data["tasks"].get(step_id)
1020+
1021+
if progress_task is None:
1022+
# First call - create and start the progress task
1023+
progress_task = self.hass.async_create_task(
1024+
func(self, *args, **kwargs), # type: ignore[arg-type]
1025+
f"Progress step {step_id}",
1026+
)
1027+
self._progress_step_data["tasks"][step_id] = progress_task
1028+
1029+
if not progress_task.done():
1030+
# Handle description placeholders
1031+
placeholders = None
1032+
if description_placeholders is not None:
1033+
if callable(description_placeholders):
1034+
placeholders = description_placeholders(self)
1035+
else:
1036+
placeholders = description_placeholders
1037+
1038+
return self.async_show_progress(
1039+
step_id=step_id,
1040+
progress_action=step_id,
1041+
progress_task=progress_task,
1042+
description_placeholders=placeholders,
1043+
)
1044+
1045+
# Task is done or this is a subsequent call
1046+
try:
1047+
self._progress_step_data["next_step_result"] = await progress_task
1048+
except AbortFlow as err:
1049+
self._progress_step_data["abort_reason"] = err.reason
1050+
self._progress_step_data["abort_description_placeholders"] = (
1051+
err.description_placeholders or {}
1052+
)
1053+
return self.async_show_progress_done(
1054+
next_step_id="_progress_step_abort"
1055+
)
1056+
finally:
1057+
# Clean up task reference
1058+
self._progress_step_data["tasks"].pop(step_id, None)
1059+
1060+
return self.async_show_progress_done(
1061+
next_step_id="_progress_step_progress_done"
1062+
)
1063+
1064+
return wrapper
1065+
1066+
return decorator

tests/components/homeassistant_hardware/test_config_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ async def test_options_flow_zigbee_to_thread(
844844

845845
assert result["type"] is FlowResultType.SHOW_PROGRESS
846846
assert result["step_id"] == "install_otbr_addon"
847-
assert result["progress_action"] == "install_addon"
847+
assert result["progress_action"] == "install_otbr_addon"
848848

849849
await hass.async_block_till_done(wait_background_tasks=True)
850850

0 commit comments

Comments
 (0)